This is an automated email from the ASF dual-hosted git repository.

nicholasjiang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new b1f8ec835 [CELEBORN-1351] Introduce SSLFactory and enable TLS support
b1f8ec835 is described below

commit b1f8ec83575d556cd5d9da1620b7fc75b4dcee60
Author: Mridul Muralidharan <mridulatgmail.com>
AuthorDate: Mon Apr 8 10:42:29 2024 +0800

    [CELEBORN-1351] Introduce SSLFactory and enable TLS support
    
    ### What changes were proposed in this pull request?
    
    Add SSLFactory, and wire up TLS support with the rest of Celeborn to enable
    secure over-the-wire communication.
    
    ### Why are the changes needed?
    Add support for TLS to secure wire communication.
    This is the last PR to add basic support for TLS.
    There will be a follow-up for CELEBORN-1356 and documentation, of course!
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, completes basic support for TLS in Celeborn.
    
    ### How was this patch tested?
    Existing tests, augmented with additional unit tests.
    
    Closes #2438 from mridulm/add-sslfactory-and-related-changes.
    
    Authored-by: Mridul Muralidharan <mridulatgmail.com>
    Signed-off-by: SteNicholas <[email protected]>
---
 LICENSE                                            |   1 +
 .../apache/celeborn/client/ShuffleClientImpl.java  |  10 +-
 .../celeborn/common/network/TransportContext.java  |  83 ++++-
 .../network/client/TransportClientFactory.java     |  33 +-
 .../common/network/server/TransportServer.java     |   2 +-
 .../celeborn/common/network/ssl/SSLFactory.java    | 393 +++++++++++++++++++++
 .../celeborn/common/rpc/netty/NettyRpcEnv.scala    |   6 +-
 .../common/network/RpcIntegrationSuiteJ.java       |  11 +-
 .../common/network/SSLRpcIntegrationSuiteJ.java    |  49 +++
 .../network/SSLTransportClientFactorySuiteJ.java   |  49 +++
 .../network/TransportClientFactorySuiteJ.java      |  25 +-
 .../celeborn/common/network/sasl/SaslTestBase.java |   3 +
 .../common/network/ssl/SslConnectivitySuiteJ.java  | 341 ++++++++++++++++++
 .../common/network/ssl/SslSampleConfigs.java       |   1 -
 .../common/rpc/netty/SSLNettyRpcEnvSuite.scala     |  44 +++
 .../celeborn/service/deploy/worker/Worker.scala    |  18 +-
 .../network/RequestTimeoutIntegrationSuiteJ.java   |  19 +-
 .../SSLRequestTimeoutIntegrationSuiteJ.java        |  49 +++
 .../storage/ChunkFetchIntegrationSuiteJ.java       |  20 +-
 .../storage/ReducePartitionDataWriterSuiteJ.java   |  24 +-
 .../storage/SSLChunkFetchIntegrationSuiteJ.java    |  49 +++
 .../SSLReducePartitionDataWriterSuiteJ.java        |  45 +++
 22 files changed, 1233 insertions(+), 42 deletions(-)

diff --git a/LICENSE b/LICENSE
index d5f68e4d7..30b865ffc 100644
--- a/LICENSE
+++ b/LICENSE
@@ -215,6 +215,7 @@ Apache Spark
 
./common/src/main/java/org/apache/celeborn/common/network/protocol/EncryptedMessageWithHeader.java
 
./common/src/main/java/org/apache/celeborn/common/network/protocol/SslMessageEncoder.java
 
./common/src/main/java/org/apache/celeborn/common/network/ssl/ReloadingX509TrustManager.java
+./common/src/main/java/org/apache/celeborn/common/network/ssl/SSLFactory.java
 ./common/src/main/java/org/apache/celeborn/common/network/util/NettyLogger.java
 ./common/src/main/java/org/apache/celeborn/common/unsafe/Platform.java
 ./common/src/main/java/org/apache/celeborn/common/util/JavaUtils.java
diff --git 
a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java 
b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
index 0494d73e9..6aeca765c 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
@@ -88,6 +88,7 @@ public class ShuffleClientImpl extends ShuffleClient {
 
   protected RpcEndpointRef lifecycleManagerRef;
 
+  private TransportContext transportContext;
   protected TransportClientFactory dataClientFactory;
 
   protected final int BATCH_HEADER_SIZE = 4 * 4;
@@ -211,12 +212,12 @@ public class ShuffleClientImpl extends ShuffleClient {
     if (dataClientFactory != null) {
       return;
     }
-    TransportContext context =
+    this.transportContext =
         new TransportContext(
             dataTransportConf, new BaseMessageHandler(), 
conf.clientCloseIdleConnections());
     if (!authEnabled) {
       logger.info("Initializing data client factory for {}.", appUniqueId);
-      dataClientFactory = context.createClientFactory();
+      dataClientFactory = transportContext.createClientFactory();
     } else if (lifecycleManagerRef != null) {
       PbApplicationMetaRequest pbApplicationMetaRequest =
           PbApplicationMetaRequest.newBuilder().setAppId(appUniqueId).build();
@@ -232,7 +233,7 @@ public class ShuffleClientImpl extends ShuffleClient {
               dataTransportConf,
               appUniqueId,
               new SaslCredentials(appUniqueId, 
pbApplicationMeta.getSecret())));
-      dataClientFactory = context.createClientFactory(bootstraps);
+      dataClientFactory = transportContext.createClientFactory(bootstraps);
     }
   }
 
@@ -1742,6 +1743,9 @@ public class ShuffleClientImpl extends ShuffleClient {
     if (null != dataClientFactory) {
       dataClientFactory.close();
     }
+    if (null != transportContext) {
+      transportContext.close();
+    }
     if (null != pushDataRetryPool) {
       pushDataRetryPool.shutdown();
     }
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/TransportContext.java 
b/common/src/main/java/org/apache/celeborn/common/network/TransportContext.java
index ec0d1fd87..58366a087 100644
--- 
a/common/src/main/java/org/apache/celeborn/common/network/TransportContext.java
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/TransportContext.java
@@ -17,14 +17,19 @@
 
 package org.apache.celeborn.common.network;
 
+import java.io.Closeable;
 import java.util.Collections;
 import java.util.List;
 
+import javax.annotation.Nullable;
+
 import io.netty.channel.Channel;
 import io.netty.channel.ChannelDuplexHandler;
 import io.netty.channel.ChannelInboundHandlerAdapter;
 import io.netty.channel.ChannelPipeline;
 import io.netty.channel.socket.SocketChannel;
+import io.netty.handler.ssl.SslHandler;
+import io.netty.handler.stream.ChunkedWriteHandler;
 import io.netty.handler.timeout.IdleStateHandler;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -35,7 +40,9 @@ import 
org.apache.celeborn.common.network.client.TransportClientBootstrap;
 import org.apache.celeborn.common.network.client.TransportClientFactory;
 import org.apache.celeborn.common.network.client.TransportResponseHandler;
 import org.apache.celeborn.common.network.protocol.MessageEncoder;
+import org.apache.celeborn.common.network.protocol.SslMessageEncoder;
 import org.apache.celeborn.common.network.server.*;
+import org.apache.celeborn.common.network.ssl.SSLFactory;
 import org.apache.celeborn.common.network.util.FrameDecoder;
 import org.apache.celeborn.common.network.util.NettyLogger;
 import org.apache.celeborn.common.network.util.TransportConf;
@@ -54,7 +61,7 @@ import 
org.apache.celeborn.common.network.util.TransportFrameDecoder;
  * channel. As each TransportChannelHandler contains a TransportClient, this 
enables server
  * processes to send messages back to the client on an existing channel.
  */
-public class TransportContext {
+public class TransportContext implements Closeable {
   private static final Logger logger = 
LoggerFactory.getLogger(TransportContext.class);
 
   private static final NettyLogger nettyLogger = new NettyLogger();
@@ -62,10 +69,13 @@ public class TransportContext {
   private final BaseMessageHandler msgHandler;
   private final ChannelDuplexHandler channelsLimiter;
   private final boolean closeIdleConnections;
+  // Non-null if SSL is enabled, null otherwise.
+  @Nullable private final SSLFactory sslFactory;
   private final boolean enableHeartbeat;
   private final AbstractSource source;
 
   private static final MessageEncoder ENCODER = MessageEncoder.INSTANCE;
+  private static final SslMessageEncoder SSL_ENCODER = 
SslMessageEncoder.INSTANCE;
 
   public TransportContext(
       TransportConf conf,
@@ -77,6 +87,7 @@ public class TransportContext {
     this.conf = conf;
     this.msgHandler = msgHandler;
     this.closeIdleConnections = closeIdleConnections;
+    this.sslFactory = createSslFactory();
     this.channelsLimiter = channelsLimiter;
     this.enableHeartbeat = enableHeartbeat;
     this.source = source;
@@ -135,31 +146,53 @@ public class TransportContext {
     return createServer(null, 0, Collections.emptyList());
   }
 
+  public boolean sslEncryptionEnabled() {
+    return this.sslFactory != null;
+  }
+
   public TransportChannelHandler initializePipeline(
-      SocketChannel channel, ChannelInboundHandlerAdapter decoder) {
-    return initializePipeline(channel, decoder, msgHandler);
+      SocketChannel channel, ChannelInboundHandlerAdapter decoder, boolean 
isClient) {
+    return initializePipeline(channel, decoder, msgHandler, isClient);
   }
 
   public TransportChannelHandler initializePipeline(
-      SocketChannel channel, BaseMessageHandler resolvedMsgHandler) {
-    return initializePipeline(channel, new TransportFrameDecoder(), 
resolvedMsgHandler);
+      SocketChannel channel, BaseMessageHandler resolvedMsgHandler, boolean 
isClient) {
+    return initializePipeline(channel, new TransportFrameDecoder(), 
resolvedMsgHandler, isClient);
   }
 
   public TransportChannelHandler initializePipeline(
       SocketChannel channel,
       ChannelInboundHandlerAdapter decoder,
-      BaseMessageHandler resolvedMsgHandler) {
+      BaseMessageHandler resolvedMsgHandler,
+      boolean isClient) {
     try {
       ChannelPipeline pipeline = channel.pipeline();
       if (nettyLogger.getLoggingHandler() != null) {
         pipeline.addLast("loggingHandler", nettyLogger.getLoggingHandler());
       }
+      if (sslEncryptionEnabled()) {
+        if (!isClient && !sslFactory.hasKeyManagers()) {
+          throw new IllegalStateException("Not a client connection and no keys 
configured");
+        }
+
+        SslHandler sslHandler;
+        try {
+          sslHandler = new SslHandler(sslFactory.createSSLEngine(isClient, 
channel.alloc()));
+        } catch (Exception e) {
+          throw new IllegalStateException("Error creating Netty SslHandler", 
e);
+        }
+        pipeline.addFirst("NettySslEncryptionHandler", sslHandler);
+        // Cannot use zero-copy with HTTPS, so we add in our 
ChunkedWriteHandler just before the
+        // MessageEncoder
+        pipeline.addLast("chunkedWriter", new ChunkedWriteHandler());
+      }
+
       if (channelsLimiter != null) {
         pipeline.addLast("limiter", channelsLimiter);
       }
       TransportChannelHandler channelHandler = createChannelHandler(channel, 
resolvedMsgHandler);
       pipeline
-          .addLast("encoder", ENCODER)
+          .addLast("encoder", sslEncryptionEnabled() ? SSL_ENCODER : ENCODER)
           .addLast(FrameDecoder.HANDLER_NAME, decoder)
           .addLast(
               "idleStateHandler",
@@ -174,6 +207,35 @@ public class TransportContext {
     }
   }
 
+  private SSLFactory createSslFactory() {
+    if (conf.sslEnabled()) {
+      if (conf.sslEnabledAndKeysAreValid()) {
+        return new SSLFactory.Builder()
+            .requestedProtocol(conf.sslProtocol())
+            .requestedCiphers(conf.sslRequestedCiphers())
+            .keyStore(conf.sslKeyStore(), conf.sslKeyStorePassword())
+            .trustStore(
+                conf.sslTrustStore(),
+                conf.sslTrustStorePassword(),
+                conf.sslTrustStoreReloadingEnabled(),
+                conf.sslTrustStoreReloadIntervalMs())
+            .build();
+      } else {
+        logger.error(
+            "SSL encryption enabled but keys not found for "
+                + conf.getModuleName()
+                + "! Please ensure the configured keys are present.");
+        throw new IllegalArgumentException(
+            conf.getModuleName()
+                + " SSL encryption enabled for "
+                + conf.getModuleName()
+                + " but keys not found!");
+      }
+    } else {
+      return null;
+    }
+  }
+
   private TransportChannelHandler createChannelHandler(
       Channel channel, BaseMessageHandler msgHandler) {
     TransportResponseHandler responseHandler = new 
TransportResponseHandler(conf, channel);
@@ -198,4 +260,11 @@ public class TransportContext {
   public BaseMessageHandler getMsgHandler() {
     return msgHandler;
   }
+
+  @Override
+  public void close() {
+    if (sslFactory != null) {
+      sslFactory.destroy();
+    }
+  }
 }
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/client/TransportClientFactory.java
 
b/common/src/main/java/org/apache/celeborn/common/network/client/TransportClientFactory.java
index f51d04a57..26191be26 100644
--- 
a/common/src/main/java/org/apache/celeborn/common/network/client/TransportClientFactory.java
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/client/TransportClientFactory.java
@@ -34,6 +34,9 @@ import io.netty.bootstrap.Bootstrap;
 import io.netty.buffer.ByteBufAllocator;
 import io.netty.channel.*;
 import io.netty.channel.socket.SocketChannel;
+import io.netty.handler.ssl.SslHandler;
+import io.netty.util.concurrent.Future;
+import io.netty.util.concurrent.GenericFutureListener;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -81,6 +84,7 @@ public class TransportClientFactory implements Closeable {
   private final int numConnectionsPerPeer;
 
   private final int connectTimeoutMs;
+  private final int connectionTimeoutMs;
 
   private final int receiveBuf;
 
@@ -97,6 +101,7 @@ public class TransportClientFactory implements Closeable {
     this.connectionPool = JavaUtils.newConcurrentHashMap();
     this.numConnectionsPerPeer = conf.numConnectionsPerPeer();
     this.connectTimeoutMs = conf.connectTimeoutMs();
+    this.connectionTimeoutMs = conf.connectionTimeoutMs();
     this.receiveBuf = conf.receiveBuf();
     this.sendBuf = conf.sendBuf();
     this.rand = new Random();
@@ -237,7 +242,7 @@ public class TransportClientFactory implements Closeable {
         new ChannelInitializer<SocketChannel>() {
           @Override
           public void initChannel(SocketChannel ch) {
-            TransportChannelHandler clientHandler = 
context.initializePipeline(ch, decoder);
+            TransportChannelHandler clientHandler = 
context.initializePipeline(ch, decoder, true);
             clientRef.set(clientHandler.getClient());
             channelRef.set(ch);
           }
@@ -252,6 +257,32 @@ public class TransportClientFactory implements Closeable {
     } else if (cf.cause() != null) {
       throw new CelebornIOException(String.format("Failed to connect to %s", 
address), cf.cause());
     }
+    if (context.sslEncryptionEnabled()) {
+      final SslHandler sslHandler = 
cf.channel().pipeline().get(SslHandler.class);
+      Future<Channel> future =
+          sslHandler
+              .handshakeFuture()
+              .addListener(
+                  new GenericFutureListener<Future<Channel>>() {
+                    @Override
+                    public void operationComplete(final Future<Channel> 
handshakeFuture) {
+                      if (handshakeFuture.isSuccess()) {
+                        logger.debug("successfully completed TLS handshake to 
{}", address);
+                      } else {
+                        logger.info(
+                            "failed to complete TLS handshake to {}",
+                            address,
+                            handshakeFuture.cause());
+                        cf.channel().close();
+                      }
+                    }
+                  });
+      if (!future.await(connectionTimeoutMs)) {
+        cf.channel().close();
+        throw new IOException(
+            String.format("Failed to connect to %s within connection timeout", 
address));
+      }
+    }
 
     TransportClient client = clientRef.get();
     assert client != null : "Channel future completed successfully with null 
client";
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/server/TransportServer.java
 
b/common/src/main/java/org/apache/celeborn/common/network/server/TransportServer.java
index 53f7b7169..d849327df 100644
--- 
a/common/src/main/java/org/apache/celeborn/common/network/server/TransportServer.java
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/server/TransportServer.java
@@ -142,7 +142,7 @@ public class TransportServer implements Closeable {
                   "Adding bootstrap to TransportServer {}.", 
bootstrap.getClass().getName());
               baseHandler = bootstrap.doBootstrap(ch, baseHandler);
             }
-            context.initializePipeline(ch, baseHandler);
+            context.initializePipeline(ch, baseHandler, false);
           }
         });
   }
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/ssl/SSLFactory.java 
b/common/src/main/java/org/apache/celeborn/common/network/ssl/SSLFactory.java
new file mode 100644
index 000000000..443db6a9c
--- /dev/null
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/ssl/SSLFactory.java
@@ -0,0 +1,393 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.network.ssl;
+
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.security.GeneralSecurityException;
+import java.security.KeyStore;
+import java.security.KeyStoreException;
+import java.security.NoSuchAlgorithmException;
+import java.security.UnrecoverableKeyException;
+import java.security.cert.CertificateException;
+import java.security.cert.X509Certificate;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+import javax.net.ssl.KeyManager;
+import javax.net.ssl.KeyManagerFactory;
+import javax.net.ssl.SSLContext;
+import javax.net.ssl.SSLEngine;
+import javax.net.ssl.TrustManager;
+import javax.net.ssl.TrustManagerFactory;
+import javax.net.ssl.X509TrustManager;
+
+import com.google.common.io.Files;
+import io.netty.buffer.ByteBufAllocator;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.celeborn.common.util.JavaUtils;
+
+/**
+ * SSLFactory to initialize and configure use of JSSE for SSL in Celeborn.
+ *
+ * <p>Note: code was initially copied from Apache Spark.
+ */
+public class SSLFactory {
+  private static final Logger logger = 
LoggerFactory.getLogger(SSLFactory.class);
+
+  /** For a configuration specifying keystore/truststore files */
+  private SSLContext jdkSslContext;
+
+  private KeyManager[] keyManagers;
+  private TrustManager[] trustManagers;
+  private String requestedProtocol;
+  private String[] requestedCiphers;
+
+  private SSLFactory(final Builder b) {
+    this.requestedProtocol = b.requestedProtocol;
+    this.requestedCiphers = b.requestedCiphers;
+    try {
+      initJdkSslContext(b);
+    } catch (Exception e) {
+      throw new RuntimeException("SSLFactory creation failed", e);
+    }
+  }
+
+  private void initJdkSslContext(final Builder b) throws IOException, 
GeneralSecurityException {
+    this.keyManagers =
+        null != b.keyStore ? keyManagers(b.keyStore, b.keyPassword, 
b.keyStorePassword) : null;
+    this.trustManagers =
+        trustStoreManagers(
+            b.trustStore, b.trustStorePassword,
+            b.trustStoreReloadingEnabled, b.trustStoreReloadIntervalMs);
+    this.jdkSslContext = createSSLContext(requestedProtocol, keyManagers, 
trustManagers);
+  }
+
+  public boolean hasKeyManagers() {
+    return null != keyManagers;
+  }
+
+  public void destroy() {
+    if (trustManagers != null) {
+      for (int i = 0; i < trustManagers.length; i++) {
+        if (trustManagers[i] instanceof ReloadingX509TrustManager) {
+          try {
+            ((ReloadingX509TrustManager) trustManagers[i]).destroy();
+          } catch (InterruptedException ex) {
+            logger.info("Interrupted while destroying trust manager: {}", ex, 
ex);
+          }
+        }
+      }
+      trustManagers = null;
+    }
+
+    keyManagers = null;
+    jdkSslContext = null;
+    requestedProtocol = null;
+    requestedCiphers = null;
+  }
+
+  /** Builder class to construct instances of {@link SSLFactory} with specific 
options */
+  public static class Builder {
+    private String requestedProtocol;
+    private String[] requestedCiphers;
+    private File keyStore;
+    private String keyStorePassword;
+    private String keyPassword;
+    private File trustStore;
+    private String trustStorePassword;
+    private boolean trustStoreReloadingEnabled;
+    private int trustStoreReloadIntervalMs;
+
+    /**
+     * Sets the requested protocol, i.e., "TLSv1.2", "TLSv1.1", etc
+     *
+     * @param requestedProtocol The requested protocol
+     * @return The builder object
+     */
+    public Builder requestedProtocol(String requestedProtocol) {
+      this.requestedProtocol = requestedProtocol == null ? "TLSv1.3" : 
requestedProtocol;
+      return this;
+    }
+
+    /**
+     * Sets the requested cipher suites, i.e., 
"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", etc
+     *
+     * @param requestedCiphers The requested ciphers
+     * @return The builder object
+     */
+    public Builder requestedCiphers(String[] requestedCiphers) {
+      this.requestedCiphers = requestedCiphers;
+      return this;
+    }
+
+    /**
+     * Sets the Keystore and Keystore password
+     *
+     * @param keyStore The key store file to use
+     * @param keyStorePassword The password for the key store
+     * @return The builder object
+     */
+    public Builder keyStore(File keyStore, String keyStorePassword) {
+      this.keyStore = keyStore;
+      this.keyStorePassword = keyStorePassword;
+      return this;
+    }
+
+    /**
+     * Sets the key password
+     *
+     * @param keyPassword The password for the private key in the key store
+     * @return The builder object
+     */
+    public Builder keyPassword(String keyPassword) {
+      this.keyPassword = keyPassword;
+      return this;
+    }
+
+    /**
+     * Sets the trust-store, trust-store password, whether to use a Reloading 
TrustStore, and the
+     * trust-store reload interval, if enabled
+     *
+     * @param trustStore The trust store file to use
+     * @param trustStorePassword The password for the trust store
+     * @param trustStoreReloadingEnabled Whether trust store reloading is 
enabled
+     * @param trustStoreReloadIntervalMs The interval at which to reload the 
trust store file
+     * @return The builder object
+     */
+    public Builder trustStore(
+        File trustStore,
+        String trustStorePassword,
+        boolean trustStoreReloadingEnabled,
+        int trustStoreReloadIntervalMs) {
+      this.trustStore = trustStore;
+      this.trustStorePassword = trustStorePassword;
+      this.trustStoreReloadingEnabled = trustStoreReloadingEnabled;
+      this.trustStoreReloadIntervalMs = trustStoreReloadIntervalMs;
+      return this;
+    }
+
+    /**
+     * Builds our {@link SSLFactory}
+     *
+     * @return The built {@link SSLFactory}
+     */
+    public SSLFactory build() {
+      return new SSLFactory(this);
+    }
+  }
+
+  /**
+   * Returns an initialized {@link SSLContext}
+   *
+   * @param requestedProtocol The requested protocol to use
+   * @param keyManagers The list of key managers to use
+   * @param trustManagers The list of trust managers to use
+   * @return The built {@link SSLContext}
+   * @throws GeneralSecurityException
+   */
+  private static SSLContext createSSLContext(
+      String requestedProtocol, KeyManager[] keyManagers, TrustManager[] 
trustManagers)
+      throws GeneralSecurityException {
+    SSLContext sslContext = SSLContext.getInstance(requestedProtocol);
+    sslContext.init(keyManagers, trustManagers, null);
+    return sslContext;
+  }
+
+  /**
+   * Creates a new {@link SSLEngine}. Note that currently client auth is not 
supported
+   *
+   * @param isClient Whether the engine is used in a client context
+   * @param allocator The {@link ByteBufAllocator to use}
+   * @return A valid {@link SSLEngine}.
+   */
+  public SSLEngine createSSLEngine(boolean isClient, ByteBufAllocator 
allocator) {
+    SSLEngine engine = createEngine(isClient, allocator);
+    engine.setUseClientMode(isClient);
+    engine.setWantClientAuth(true);
+    engine.setEnabledProtocols(enabledProtocols(engine, requestedProtocol));
+    engine.setEnabledCipherSuites(enabledCipherSuites(engine, 
requestedCiphers));
+    return engine;
+  }
+
+  private SSLEngine createEngine(boolean isClient, ByteBufAllocator allocator) 
{
+    return jdkSslContext.createSSLEngine();
+  }
+
+  private static final X509Certificate[] EMPTY_CERT_ARRAY = new 
X509Certificate[0];
+
+  private static TrustManager[] credulousTrustStoreManagers() {
+    return new TrustManager[] {
+      new X509TrustManager() {
+        @Override
+        public void checkClientTrusted(X509Certificate[] x509Certificates, 
String s)
+            throws CertificateException {}
+
+        @Override
+        public void checkServerTrusted(X509Certificate[] x509Certificates, 
String s)
+            throws CertificateException {}
+
+        @Override
+        public X509Certificate[] getAcceptedIssuers() {
+          return EMPTY_CERT_ARRAY;
+        }
+      }
+    };
+  }
+
+  private static TrustManager[] trustStoreManagers(
+      File trustStore,
+      String trustStorePassword,
+      boolean trustStoreReloadingEnabled,
+      int trustStoreReloadIntervalMs)
+      throws IOException, GeneralSecurityException {
+    if (trustStore == null || !trustStore.exists()) {
+      return credulousTrustStoreManagers();
+    } else {
+      if (trustStoreReloadingEnabled) {
+        ReloadingX509TrustManager reloading =
+            new ReloadingX509TrustManager(
+                KeyStore.getDefaultType(),
+                trustStore,
+                trustStorePassword,
+                trustStoreReloadIntervalMs);
+        reloading.init();
+        return new TrustManager[] {reloading};
+      } else {
+        return defaultTrustManagers(trustStore, trustStorePassword);
+      }
+    }
+  }
+
+  private static TrustManager[] defaultTrustManagers(File trustStore, String 
trustStorePassword)
+      throws IOException, KeyStoreException, CertificateException, 
NoSuchAlgorithmException {
+    try (InputStream input = Files.asByteSource(trustStore).openStream()) {
+      KeyStore ks = KeyStore.getInstance(KeyStore.getDefaultType());
+      char[] passwordCharacters =
+          trustStorePassword != null ? trustStorePassword.toCharArray() : null;
+      ks.load(input, passwordCharacters);
+      TrustManagerFactory tmf =
+          
TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
+      tmf.init(ks);
+      return tmf.getTrustManagers();
+    }
+  }
+
+  private static KeyManager[] keyManagers(
+      File keyStore, String keyPassword, String keyStorePassword)
+      throws NoSuchAlgorithmException, CertificateException, 
KeyStoreException, IOException,
+          UnrecoverableKeyException {
+    KeyManagerFactory factory =
+        KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
+    char[] keyStorePasswordChars = keyStorePassword != null ? 
keyStorePassword.toCharArray() : null;
+    char[] keyPasswordChars =
+        keyPassword != null ? keyPassword.toCharArray() : 
keyStorePasswordChars;
+    factory.init(loadKeyStore(keyStore, keyStorePasswordChars), 
keyPasswordChars);
+    return factory.getKeyManagers();
+  }
+
+  private static KeyStore loadKeyStore(File keyStore, char[] keyStorePassword)
+      throws KeyStoreException, IOException, CertificateException, 
NoSuchAlgorithmException {
+    if (keyStore == null) {
+      throw new KeyStoreException(
+          "keyStore cannot be null. Please configure 
celeborn.ssl.<module>.keyStore");
+    }
+
+    KeyStore ks = KeyStore.getInstance(KeyStore.getDefaultType());
+    FileInputStream fin = new FileInputStream(keyStore);
+    try {
+      ks.load(fin, keyStorePassword);
+      return ks;
+    } finally {
+      JavaUtils.closeQuietly(fin);
+    }
+  }
+
+  private static String[] enabledProtocols(SSLEngine engine, String 
requestedProtocol) {
+    String[] supportedProtocols = engine.getSupportedProtocols();
+    String[] defaultProtocols = {"TLSv1.3", "TLSv1.2"};
+    String[] enabledProtocols =
+        ((requestedProtocol == null || requestedProtocol.isEmpty())
+            ? defaultProtocols
+            : new String[] {requestedProtocol});
+
+    List<String> protocols = addIfSupported(supportedProtocols, 
enabledProtocols);
+    if (!protocols.isEmpty()) {
+      return protocols.toArray(new String[protocols.size()]);
+    } else {
+      return supportedProtocols;
+    }
+  }
+
+  private static String[] enabledCipherSuites(
+      String[] supportedCiphers, String[] defaultCiphers, String[] 
requestedCiphers) {
+    String[] baseCiphers =
+        new String[] {
+          // We take ciphers from the mozilla modern list first (for TLS 1.3):
+          // https://wiki.mozilla.org/Security/Server_Side_TLS
+          "TLS_CHACHA20_POLY1305_SHA256",
+          "TLS_AES_128_GCM_SHA256",
+          "TLS_AES_256_GCM_SHA384",
+          // Next we have the TLS1.2 ciphers for intermediate compatibility 
(since JDK8 does not
+          // support TLS1.3)
+          "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256",
+          "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384",
+          "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256",
+          "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384",
+          "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256",
+          "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256",
+          "TLS_DHE_RSA_WITH_AES_128_GCM_SHA256",
+          "TLS_DHE_RSA_WITH_AES_256_GCM_SHA384"
+        };
+    String[] enabledCiphers =
+        ((requestedCiphers == null || requestedCiphers.length == 0)
+            ? baseCiphers
+            : requestedCiphers);
+
+    List<String> ciphers = addIfSupported(supportedCiphers, enabledCiphers);
+    if (!ciphers.isEmpty()) {
+      return ciphers.toArray(new String[ciphers.size()]);
+    } else {
+      // Use the default from JDK as fallback.
+      return defaultCiphers;
+    }
+  }
+
+  private static String[] enabledCipherSuites(SSLEngine engine, String[] 
requestedCiphers) {
+    return enabledCipherSuites(
+        engine.getSupportedCipherSuites(), engine.getEnabledCipherSuites(), 
requestedCiphers);
+  }
+
+  private static List<String> addIfSupported(String[] supported, String... 
names) {
+    List<String> enabled = new ArrayList<>();
+    Set<String> supportedSet = new HashSet<>(Arrays.asList(supported));
+    for (String n : names) {
+      if (supportedSet.contains(n)) {
+        enabled.add(n);
+      }
+    }
+    return enabled;
+  }
+}
diff --git 
a/common/src/main/scala/org/apache/celeborn/common/rpc/netty/NettyRpcEnv.scala 
b/common/src/main/scala/org/apache/celeborn/common/rpc/netty/NettyRpcEnv.scala
index b0f470efe..151a79104 100644
--- 
a/common/src/main/scala/org/apache/celeborn/common/rpc/netty/NettyRpcEnv.scala
+++ 
b/common/src/main/scala/org/apache/celeborn/common/rpc/netty/NettyRpcEnv.scala
@@ -59,7 +59,8 @@ class NettyRpcEnv(
 
   private var worker: RpcEndpoint = null
 
-  private val transportContext =
+  // Visible for tests
+  private[netty] val transportContext =
     new TransportContext(transportConf, new NettyRpcHandler(dispatcher, this))
 
   private def createClientBootstraps(): 
java.util.List[TransportClientBootstrap] = {
@@ -342,6 +343,9 @@ class NettyRpcEnv(
     if (clientFactory != null) {
       clientFactory.close()
     }
+    if (null != transportContext) {
+      transportContext.close();
+    }
     if (clientConnectionExecutor != null) {
       clientConnectionExecutor.shutdownNow()
     }
diff --git 
a/common/src/test/java/org/apache/celeborn/common/network/RpcIntegrationSuiteJ.java
 
b/common/src/test/java/org/apache/celeborn/common/network/RpcIntegrationSuiteJ.java
index ec1834b4d..76cc0dc2c 100644
--- 
a/common/src/test/java/org/apache/celeborn/common/network/RpcIntegrationSuiteJ.java
+++ 
b/common/src/test/java/org/apache/celeborn/common/network/RpcIntegrationSuiteJ.java
@@ -43,7 +43,9 @@ import org.apache.celeborn.common.network.util.TransportConf;
 import org.apache.celeborn.common.util.JavaUtils;
 
 public class RpcIntegrationSuiteJ {
+  static final String TEST_MODULE = "shuffle";
   static TransportConf conf;
+  static TransportContext context;
   static TransportServer server;
   static TransportClientFactory clientFactory;
   static BaseMessageHandler handler;
@@ -52,7 +54,11 @@ public class RpcIntegrationSuiteJ {
 
   @BeforeClass
   public static void setUp() throws Exception {
-    conf = new TransportConf("shuffle", new CelebornConf());
+    initialize((new CelebornConf()));
+  }
+
+  static void initialize(CelebornConf celebornConf) throws Exception {
+    conf = new TransportConf(TEST_MODULE, celebornConf);
     testData = new StreamTestHelper();
     handler =
         new BaseMessageHandler() {
@@ -91,7 +97,7 @@ public class RpcIntegrationSuiteJ {
             return true;
           }
         };
-    TransportContext context = new TransportContext(conf, handler);
+    context = new TransportContext(conf, handler);
     server = context.createServer();
     clientFactory = context.createClientFactory();
     oneWayMsgs = new ArrayList<>();
@@ -101,6 +107,7 @@ public class RpcIntegrationSuiteJ {
   public static void tearDown() {
     server.close();
     clientFactory.close();
+    context.close();
     testData.cleanup();
   }
 
diff --git 
a/common/src/test/java/org/apache/celeborn/common/network/SSLRpcIntegrationSuiteJ.java
 
b/common/src/test/java/org/apache/celeborn/common/network/SSLRpcIntegrationSuiteJ.java
new file mode 100644
index 000000000..5c77f995a
--- /dev/null
+++ 
b/common/src/test/java/org/apache/celeborn/common/network/SSLRpcIntegrationSuiteJ.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.network;
+
+import static org.junit.Assert.assertTrue;
+
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.common.network.ssl.SslSampleConfigs;
+
+public class SSLRpcIntegrationSuiteJ extends RpcIntegrationSuiteJ {
+
+  @BeforeClass
+  public static void setUp() throws Exception {
+    // set up SSL for TEST_MODULE
+    RpcIntegrationSuiteJ.initialize(
+        TestHelper.updateCelebornConfWithMap(
+            new CelebornConf(), 
SslSampleConfigs.createDefaultConfigMapForModule(TEST_MODULE)));
+  }
+
+  @AfterClass
+  public static void tearDown() {
+    RpcIntegrationSuiteJ.tearDown();
+  }
+
+  @Test
+  public void validateSslConfig() {
+    // this is to ensure ssl config has been applied.
+    assertTrue(conf.sslEnabled());
+  }
+}
diff --git 
a/common/src/test/java/org/apache/celeborn/common/network/SSLTransportClientFactorySuiteJ.java
 
b/common/src/test/java/org/apache/celeborn/common/network/SSLTransportClientFactorySuiteJ.java
new file mode 100644
index 000000000..8c4884373
--- /dev/null
+++ 
b/common/src/test/java/org/apache/celeborn/common/network/SSLTransportClientFactorySuiteJ.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.network;
+
+import static org.junit.Assert.assertTrue;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.common.network.ssl.SslSampleConfigs;
+
+public class SSLTransportClientFactorySuiteJ extends 
TransportClientFactorySuiteJ {
+
+  @Before
+  public void setUp() {
+    // set up SSL for TEST_MODULE
+    doSetup(
+        TestHelper.updateCelebornConfWithMap(
+            new CelebornConf(), 
SslSampleConfigs.createDefaultConfigMapForModule(TEST_MODULE)));
+  }
+
+  @After
+  public void tearDown() {
+    super.tearDown();
+  }
+
+  @Test
+  public void validateSslConfig() {
+    // this is to ensure ssl config has been applied.
+    assertTrue(getTransportContextConf().sslEnabled());
+  }
+}
diff --git 
a/common/src/test/java/org/apache/celeborn/common/network/TransportClientFactorySuiteJ.java
 
b/common/src/test/java/org/apache/celeborn/common/network/TransportClientFactorySuiteJ.java
index 022af8178..26a9b4885 100644
--- 
a/common/src/test/java/org/apache/celeborn/common/network/TransportClientFactorySuiteJ.java
+++ 
b/common/src/test/java/org/apache/celeborn/common/network/TransportClientFactorySuiteJ.java
@@ -37,23 +37,36 @@ import 
org.apache.celeborn.common.network.util.TransportConf;
 import org.apache.celeborn.common.util.JavaUtils;
 
 public class TransportClientFactorySuiteJ {
+
+  static final String TEST_MODULE = "shuffle";
+
   private TransportContext context;
   private TransportServer server1;
   private TransportServer server2;
 
-  @Before
-  public void setUp() {
-    TransportConf conf = new TransportConf("shuffle", new CelebornConf());
+  protected void doSetup(CelebornConf celebornConf) {
+    TransportConf conf = new TransportConf(TEST_MODULE, celebornConf);
     BaseMessageHandler handler = new BaseMessageHandler();
     context = new TransportContext(conf, handler);
     server1 = context.createServer();
     server2 = context.createServer();
   }
 
+  @Before
+  public void setUp() {
+    doSetup(new CelebornConf());
+  }
+
+  // for validation in subclasses
+  TransportConf getTransportContextConf() {
+    return context.getConf();
+  }
+
   @After
   public void tearDown() {
     JavaUtils.closeQuietly(server1);
     JavaUtils.closeQuietly(server2);
+    JavaUtils.closeQuietly(context);
   }
 
   /**
@@ -68,7 +81,7 @@ public class TransportClientFactorySuiteJ {
 
     CelebornConf _conf = new CelebornConf();
     _conf.set("celeborn.shuffle.io.numConnectionsPerPeer", 
Integer.toString(maxConnections));
-    TransportConf conf = new TransportConf("shuffle", _conf);
+    TransportConf conf = new TransportConf(TEST_MODULE, _conf);
 
     BaseMessageHandler handler = new BaseMessageHandler();
     TransportContext context = new TransportContext(conf, handler);
@@ -114,6 +127,7 @@ public class TransportClientFactorySuiteJ {
     }
 
     factory.close();
+    context.close();
   }
 
   @Test
@@ -177,7 +191,7 @@ public class TransportClientFactorySuiteJ {
   public void closeIdleConnectionForRequestTimeOut() throws IOException, 
InterruptedException {
     CelebornConf _conf = new CelebornConf();
     _conf.set("celeborn.shuffle.io.connectionTimeout", "1s");
-    TransportConf conf = new TransportConf("shuffle", _conf);
+    TransportConf conf = new TransportConf(TEST_MODULE, _conf);
     TransportContext context = new TransportContext(conf, new 
BaseMessageHandler(), true);
     try (TransportClientFactory factory = context.createClientFactory()) {
       TransportClient c1 = factory.createClient(getLocalHost(), 
server1.getPort());
@@ -188,6 +202,7 @@ public class TransportClientFactorySuiteJ {
       }
       assertFalse(c1.isActive());
     }
+    context.close();
   }
 
   @Test(expected = IOException.class)
diff --git 
a/common/src/test/java/org/apache/celeborn/common/network/sasl/SaslTestBase.java
 
b/common/src/test/java/org/apache/celeborn/common/network/sasl/SaslTestBase.java
index c6ca51831..6a34adb08 100644
--- 
a/common/src/test/java/org/apache/celeborn/common/network/sasl/SaslTestBase.java
+++ 
b/common/src/test/java/org/apache/celeborn/common/network/sasl/SaslTestBase.java
@@ -137,6 +137,9 @@ public class SaslTestBase {
       if (server != null) {
         server.close();
       }
+      if (null != ctx) {
+        ctx.close();
+      }
     }
   }
 }
diff --git 
a/common/src/test/java/org/apache/celeborn/common/network/ssl/SslConnectivitySuiteJ.java
 
b/common/src/test/java/org/apache/celeborn/common/network/ssl/SslConnectivitySuiteJ.java
new file mode 100644
index 000000000..5f8b1902c
--- /dev/null
+++ 
b/common/src/test/java/org/apache/celeborn/common/network/ssl/SslConnectivitySuiteJ.java
@@ -0,0 +1,341 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.network.ssl;
+
+import static org.apache.celeborn.common.util.JavaUtils.getLocalHost;
+import static org.junit.Assert.*;
+
+import java.io.Closeable;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.Map;
+import java.util.concurrent.Semaphore;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Function;
+
+import org.apache.commons.lang3.tuple.Pair;
+import org.junit.Test;
+
+import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.common.network.TestHelper;
+import org.apache.celeborn.common.network.TransportContext;
+import org.apache.celeborn.common.network.client.RpcResponseCallback;
+import org.apache.celeborn.common.network.client.TransportClient;
+import org.apache.celeborn.common.network.client.TransportClientFactory;
+import org.apache.celeborn.common.network.protocol.RequestMessage;
+import org.apache.celeborn.common.network.server.BaseMessageHandler;
+import org.apache.celeborn.common.network.server.TransportServer;
+import org.apache.celeborn.common.network.util.TransportConf;
+import org.apache.celeborn.common.util.JavaUtils;
+
+/**
+ * A few negative tests to ensure that a non-SSL client cannot talk to an SSL 
server and vice versa.
+ * Also, a few tests to ensure non-SSL client and servers can talk to each 
other, SSL client and
+ * server also can talk to each other.
+ */
+public class SslConnectivitySuiteJ {
+
+  private static final String TEST_MODULE = "rpc";
+
+  private static final String RESPONSE_PREFIX = "Test-prefix...";
+  private static final String RESPONSE_SUFFIX = "...Suffix";
+
+  private static final TestBaseMessageHandler DEFAULT_HANDLER = new 
TestBaseMessageHandler();
+
+  private static TransportConf createTransportConf(
+      String module,
+      boolean enableSsl,
+      boolean useDefault,
+      Function<CelebornConf, CelebornConf> postProcessConf) {
+
+    CelebornConf celebornConf = new CelebornConf();
+    // in case the default gets flipped to true in future
+    celebornConf.set("celeborn.ssl." + module + ".enabled", "false");
+    if (enableSsl) {
+      Map<String, String> configMap;
+      if (useDefault) {
+        configMap = SslSampleConfigs.createDefaultConfigMapForModule(module);
+      } else {
+        configMap = SslSampleConfigs.createAnotherConfigMapForModule(module);
+      }
+      TestHelper.updateCelebornConfWithMap(celebornConf, configMap);
+    }
+
+    celebornConf = postProcessConf.apply(celebornConf);
+
+    TransportConf conf = new TransportConf(module, celebornConf);
+    assertEquals(enableSsl, conf.sslEnabled());
+
+    return conf;
+  }
+
+  private static class TestTransportState implements Closeable {
+    final TransportContext serverContext;
+    final TransportContext clientContext;
+    final TransportServer server;
+    final TransportClientFactory clientFactory;
+
+    TestTransportState(
+        BaseMessageHandler handler, TransportConf serverConf, TransportConf 
clientConf) {
+      this.serverContext = new TransportContext(serverConf, handler);
+      this.clientContext = new TransportContext(clientConf, handler);
+
+      this.server = serverContext.createServer();
+      this.clientFactory = clientContext.createClientFactory();
+    }
+
+    TransportClient createClient() throws IOException, InterruptedException {
+      return clientFactory.createClient(getLocalHost(), server.getPort());
+    }
+
+    @Override
+    public void close() throws IOException {
+      JavaUtils.closeQuietly(server);
+      JavaUtils.closeQuietly(clientFactory);
+      JavaUtils.closeQuietly(serverContext);
+      JavaUtils.closeQuietly(clientContext);
+    }
+  }
+
+  private static class TestBaseMessageHandler extends BaseMessageHandler {
+
+    @Override
+    public void receive(
+        TransportClient client, RequestMessage requestMessage, 
RpcResponseCallback callback) {
+      try {
+        String msg = 
JavaUtils.bytesToString(requestMessage.body().nioByteBuffer());
+        String response = RESPONSE_PREFIX + msg + RESPONSE_SUFFIX;
+        callback.onSuccess(JavaUtils.stringToBytes(response));
+      } catch (Exception e) {
+        throw new RuntimeException(e);
+      }
+    }
+
+    @Override
+    public boolean checkRegistered() {
+      return true;
+    }
+  }
+
+  // Pair<Success, Failure> ... should be an Either actually.
+
+  private Pair<String, String> sendRPC(TransportClient client, String message, 
boolean canTimeout)
+      throws Exception {
+    final Semaphore sem = new Semaphore(0);
+
+    final AtomicReference<String> response = new AtomicReference<>(null);
+    final AtomicReference<String> errorResponse = new AtomicReference<>(null);
+
+    RpcResponseCallback callback =
+        new RpcResponseCallback() {
+          @Override
+          public void onSuccess(ByteBuffer message) {
+            String res = JavaUtils.bytesToString(message);
+            response.set(res);
+            sem.release();
+          }
+
+          @Override
+          public void onFailure(Throwable e) {
+            errorResponse.set(e.getMessage());
+            sem.release();
+          }
+        };
+
+    client.sendRpc(JavaUtils.stringToBytes(message), callback);
+
+    if (!sem.tryAcquire(1, 5, TimeUnit.SECONDS)) {
+      if (canTimeout) {
+        throw new IOException("Timed out sending rpc message");
+      } else {
+        fail("Timeout getting response from the server");
+      }
+    }
+
+    return Pair.of(response.get(), errorResponse.get());
+  }
+
+  // A basic validation test to check if non-ssl client and non-ssl server can 
talk to each other.
+  // This not only validates non-SSL flow, but also is a check to verify that 
the negative
+  // tests are not indicating any other error.
+  @Test
+  public void testNormalConnectivityWorks() throws Exception {
+    testSuccessfulConnectivity(
+        false,
+        // does not matter what these are set to
+        true,
+        true,
+        Function.identity(),
+        Function.identity());
+  }
+
+  // Both server and client are on SSL, and should be able to successfully 
communicate
+  // This is the SSL version of testNormalConnectivityWorks above.
+  @Test
+  public void testBasicSslClientConnectivityWorks() throws Exception {
+
+    final Function<CelebornConf, CelebornConf> updateClientConf =
+        conf -> {
+          // ignore incoming conf, and return a new conf which only has ssl enabled 
= true
+          // This is essentially testing two things:
+          // a) client can talk SSL to a server, and does not need anything 
else to be specified
+          // b) When no truststore is configured at the client, it trusts all 
server certs
+          CelebornConf newConf = new CelebornConf();
+          newConf.set("celeborn.ssl." + TEST_MODULE + ".enabled", "true");
+          return newConf;
+        };
+
+    // primaryConfigForClient param does not matter - we are completely 
overriding it above
+    // in updateClientConf. Adding both just for completeness sake.
+    testSuccessfulConnectivity(true, true, true, Function.identity(), 
updateClientConf);
+    testSuccessfulConnectivity(true, true, false, Function.identity(), 
updateClientConf);
+
+    // just for validation, this should fail (ssl for client, plain for 
server) ...
+    testConnectivityFailure(false, true, false, false, Function.identity(), 
updateClientConf);
+  }
+
+  // Only an SSL client can talk to an SSL server.
+  @Test
+  public void testSslServerNormalClientFails() throws Exception {
+    // Will fail for both primary and secondary jks (primaryConfigForServer) - 
adding
+    // both just for completeness
+    testConnectivityFailure(true, false, true, true, Function.identity(), 
Function.identity());
+    testConnectivityFailure(true, false, false, true, Function.identity(), 
Function.identity());
+  }
+
+  // An SSL client cannot talk to a non-SSL server.
+  @Test
+  public void testSslClientNormalServerFails() throws Exception {
+    // Will fail for both values of primaryConfigForClient - adding both just for 
completeness
+    testConnectivityFailure(false, true, false, false, Function.identity(), 
Function.identity());
+    testConnectivityFailure(false, true, false, true, Function.identity(), 
Function.identity());
+  }
+
+  @Test
+  public void testUntrustedServerCertFails() throws Exception {
+    final String trustStoreKey = "celeborn.ssl." + TEST_MODULE + ".trustStore";
+
+    final Function<CelebornConf, CelebornConf> updateConf =
+        conf -> {
+          assertTrue(conf.getOption(trustStoreKey).isDefined());
+          conf.set(trustStoreKey, SslSampleConfigs.TRUST_STORE_WITHOUT_CA);
+          return conf;
+        };
+
+    // will fail for all combinations - since we don't have the CAs in the 
truststore
+    testConnectivityFailure(true, true, true, false, updateConf, updateConf);
+    testConnectivityFailure(true, true, true, true, updateConf, updateConf);
+    testConnectivityFailure(true, true, false, true, updateConf, updateConf);
+    testConnectivityFailure(true, true, false, false, updateConf, updateConf);
+  }
+
+  // This is a variant of testUntrustedServerCertFails - where the server does 
not trust the
+  // client cert, and the client does provide a cert.
+  @Test
+  public void testUntrustedClientCertFails() throws Exception {
+    final String trustStoreKey = "celeborn.ssl." + TEST_MODULE + ".trustStore";
+
+    final Function<CelebornConf, CelebornConf> updateConf =
+        conf -> {
+          assertTrue(conf.getOption(trustStoreKey).isDefined());
+          conf.set(trustStoreKey, SslSampleConfigs.TRUST_STORE_WITHOUT_CA);
+          return conf;
+        };
+
+    // will fail for all combinations - since server does not have the client 
cert's CA
+    // in its truststore
+    testConnectivityFailure(true, true, true, false, updateConf, 
Function.identity());
+    testConnectivityFailure(true, true, true, true, updateConf, 
Function.identity());
+    testConnectivityFailure(true, true, false, true, updateConf, 
Function.identity());
+    testConnectivityFailure(true, true, false, false, updateConf, 
Function.identity());
+  }
+
+  @Test
+  public void testUntrustedServerCertWorksIfTrustStoreDisabled() throws 
Exception {
+    // Same as testUntrustedServerCertFails, but remove the truststore - which 
should result in
+    // accepting all certs. Note, for jks at client side
+
+    final String trustStoreKey = "celeborn.ssl." + TEST_MODULE + ".trustStore";
+
+    final Function<CelebornConf, CelebornConf> updateConf =
+        conf -> {
+          assertTrue(conf.getOption(trustStoreKey).isDefined());
+          conf.unset(trustStoreKey);
+          assertNull(new TransportConf(TEST_MODULE, conf).sslTrustStore());
+          return conf;
+        };
+
+    // checking nettyssl == false at both client and server does not make 
sense in this context
+    // it is the same cert for both :)
+    testSuccessfulConnectivity(true, true, true, updateConf, updateConf);
+    testSuccessfulConnectivity(true, true, false, updateConf, updateConf);
+    testSuccessfulConnectivity(true, false, true, updateConf, updateConf);
+    testSuccessfulConnectivity(true, false, false, updateConf, updateConf);
+  }
+
+  private void testSuccessfulConnectivity(
+      boolean enableSsl,
+      boolean primaryConfigForServer,
+      boolean primaryConfigForClient,
+      Function<CelebornConf, CelebornConf> postProcessServerConf,
+      Function<CelebornConf, CelebornConf> postProcessClientConf)
+      throws Exception {
+    try (TestTransportState state =
+            new TestTransportState(
+                DEFAULT_HANDLER,
+                createTransportConf(
+                    TEST_MODULE, enableSsl, primaryConfigForServer, 
postProcessServerConf),
+                createTransportConf(
+                    TEST_MODULE, enableSsl, primaryConfigForClient, 
postProcessClientConf));
+        TransportClient client = state.createClient()) {
+
+      String msg = " hi ";
+      Pair<String, String> response = sendRPC(client, msg, false);
+      assertNotNull("Failed ? " + response.getRight(), response.getLeft());
+      assertNull(response.getRight());
+      assertEquals(RESPONSE_PREFIX + msg + RESPONSE_SUFFIX, 
response.getLeft());
+    }
+  }
+
+  private void testConnectivityFailure(
+      boolean serverSsl,
+      boolean clientSsl,
+      boolean primaryConfigForServer,
+      boolean primaryConfigForClient,
+      Function<CelebornConf, CelebornConf> postProcessServerConf,
+      Function<CelebornConf, CelebornConf> postProcessClientConf)
+      throws Exception {
+    try (TestTransportState state =
+            new TestTransportState(
+                DEFAULT_HANDLER,
+                createTransportConf(
+                    TEST_MODULE, serverSsl, primaryConfigForServer, 
postProcessServerConf),
+                createTransportConf(
+                    TEST_MODULE, clientSsl, primaryConfigForClient, 
postProcessClientConf));
+        TransportClient client = state.createClient()) {
+
+      String msg = " hi ";
+      Pair<String, String> response = sendRPC(client, msg, true);
+      assertNull(response.getLeft());
+      assertNotNull(response.getRight());
+    } catch (IOException ioEx) {
+      // this is fine - expected to fail
+    }
+  }
+}
diff --git 
a/common/src/test/java/org/apache/celeborn/common/network/ssl/SslSampleConfigs.java
 
b/common/src/test/java/org/apache/celeborn/common/network/ssl/SslSampleConfigs.java
index 653709a31..c33baf92a 100644
--- 
a/common/src/test/java/org/apache/celeborn/common/network/ssl/SslSampleConfigs.java
+++ 
b/common/src/test/java/org/apache/celeborn/common/network/ssl/SslSampleConfigs.java
@@ -62,7 +62,6 @@ public class SslSampleConfigs {
     Map<String, String> confMap = new HashMap<>();
     confMap.put("celeborn.ssl." + module + ".enabled", "true");
     confMap.put("celeborn.ssl." + module + ".trustStoreReloadingEnabled", 
"false");
-    confMap.put("celeborn.ssl." + module + ".openSslEnabled", "false");
     confMap.put("celeborn.ssl." + module + ".trustStoreReloadIntervalMs", 
"10000");
     if (forDefault) {
       confMap.put("celeborn.ssl." + module + ".keyStore", 
DEFAULT_KEY_STORE_PATH);
diff --git 
a/common/src/test/scala/org/apache/celeborn/common/rpc/netty/SSLNettyRpcEnvSuite.scala
 
b/common/src/test/scala/org/apache/celeborn/common/rpc/netty/SSLNettyRpcEnvSuite.scala
new file mode 100644
index 000000000..79de7f313
--- /dev/null
+++ 
b/common/src/test/scala/org/apache/celeborn/common/rpc/netty/SSLNettyRpcEnvSuite.scala
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.rpc.netty
+
+import org.apache.celeborn.common.CelebornConf
+import org.apache.celeborn.common.network.TestHelper
+import org.apache.celeborn.common.network.ssl.SslSampleConfigs
+import org.apache.celeborn.common.protocol.TransportModuleConstants
+
+class SSLNettyRpcEnvSuite extends NettyRpcEnvSuite {
+
+  override def createCelebornConf(): CelebornConf = {
+    val conf = super.createCelebornConf()
+    TestHelper.updateCelebornConfWithMap(
+      conf,
+      
SslSampleConfigs.createDefaultConfigMapForModule(TransportModuleConstants.RPC_MODULE))
+    conf
+  }
+
+  test("verify rpc env is using SSL") {
+    env match {
+      case nettyRpcEnv: NettyRpcEnv =>
+        assert(nettyRpcEnv.transportContext.getConf.sslEnabled())
+        assert(nettyRpcEnv.transportContext.sslEncryptionEnabled())
+      case _ =>
+        throw new IllegalArgumentException("Expected NettyRpcEnv, found = " + 
env)
+    }
+  }
+}
diff --git 
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala 
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala
index 5ec378060..88ad05c4e 100644
--- 
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala
+++ 
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala
@@ -196,7 +196,7 @@ private[celeborn] class Worker(
   }
 
   val pushDataHandler = new PushDataHandler(workerSource)
-  private val pushServer = {
+  private val (pushServerTransportContext, pushServer) = {
     val closeIdleConnections = conf.workerCloseIdleConnections
     val numThreads = 
conf.workerPushIoThreads.getOrElse(storageManager.totalFlusherThread)
     val transportConf =
@@ -210,11 +210,13 @@ private[celeborn] class Worker(
         pushServerLimiter,
         conf.workerPushHeartbeatEnabled,
         workerSource)
-    transportContext.createServer(conf.workerPushPort, 
getServerBootstraps(transportConf))
+    (
+      transportContext,
+      transportContext.createServer(conf.workerPushPort, 
getServerBootstraps(transportConf)))
   }
 
   val replicateHandler = new PushDataHandler(workerSource)
-  val (replicateServer, replicateClientFactory) = {
+  val (replicateTransportContext, replicateServer, replicateClientFactory) = {
     val closeIdleConnections = conf.workerCloseIdleConnections
     val numThreads =
       
conf.workerReplicateIoThreads.getOrElse(storageManager.totalFlusherThread)
@@ -230,12 +232,13 @@ private[celeborn] class Worker(
         false,
         workerSource)
     (
+      transportContext,
       transportContext.createServer(conf.workerReplicatePort),
       transportContext.createClientFactory())
   }
 
   var fetchHandler: FetchHandler = _
-  private val fetchServer = {
+  private val (fetchServerTransportContext, fetchServer) = {
     val closeIdleConnections = conf.workerCloseIdleConnections
     val numThreads = 
conf.workerFetchIoThreads.getOrElse(storageManager.totalFlusherThread)
     val transportConf =
@@ -248,7 +251,9 @@ private[celeborn] class Worker(
         closeIdleConnections,
         conf.workerFetchHeartbeatEnabled,
         workerSource)
-    transportContext.createServer(conf.workerFetchPort, 
getServerBootstraps(transportConf))
+    (
+      transportContext,
+      transportContext.createServer(conf.workerFetchPort, 
getServerBootstraps(transportConf)))
   }
 
   private val pushPort = pushServer.getPort
@@ -551,6 +556,9 @@ private[celeborn] class Worker(
       replicateServer.shutdown(exitKind)
       fetchServer.shutdown(exitKind)
       pushServer.shutdown(exitKind)
+      replicateTransportContext.close()
+      fetchServerTransportContext.close()
+      pushServerTransportContext.close()
       metricsSystem.stop()
       if (conf.internalPortEnabled) {
         internalRpcEnvInUse.stop(internalRpcEndpointRef)
diff --git 
a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/network/RequestTimeoutIntegrationSuiteJ.java
 
b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/network/RequestTimeoutIntegrationSuiteJ.java
index fb083d392..d9494c429 100644
--- 
a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/network/RequestTimeoutIntegrationSuiteJ.java
+++ 
b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/network/RequestTimeoutIntegrationSuiteJ.java
@@ -60,6 +60,8 @@ import 
org.apache.celeborn.service.deploy.worker.storage.ChunkStreamManager;
  */
 public class RequestTimeoutIntegrationSuiteJ {
 
+  static final String TEST_MODULE = "shuffle";
+
   private TransportServer server;
   private TransportClientFactory clientFactory;
 
@@ -68,11 +70,14 @@ public class RequestTimeoutIntegrationSuiteJ {
   // A large timeout that "shouldn't happen", for the sake of faulty tests not 
hanging forever.
   private static final int FOREVER = 60 * 1000;
 
+  protected void doSetup(CelebornConf celebornConf) {
+    celebornConf.set("celeborn.shuffle.io.connectionTimeout", "2s");
+    conf = new TransportConf(TEST_MODULE, celebornConf);
+  }
+
   @Before
   public void setUp() throws Exception {
-    CelebornConf _conf = new CelebornConf();
-    _conf.set("celeborn.shuffle.io.connectionTimeout", "2s");
-    conf = new TransportConf("shuffle", _conf);
+    doSetup(new CelebornConf());
   }
 
   @After
@@ -85,6 +90,11 @@ public class RequestTimeoutIntegrationSuiteJ {
     }
   }
 
+  // for validation in subclasses
+  protected TransportConf getConf() {
+    return this.conf;
+  }
+
   // Basic suite: First request completes quickly, and second waits for longer 
than network timeout.
   @Test
   public void timeoutInactiveRequests() throws Exception {
@@ -136,6 +146,7 @@ public class RequestTimeoutIntegrationSuiteJ {
     callback1.latch.await(60, TimeUnit.SECONDS);
     assertNotNull(callback1.failure);
     assertTrue(callback1.failure instanceof IOException);
+    context.close();
 
     semaphore.release();
   }
@@ -195,6 +206,7 @@ public class RequestTimeoutIntegrationSuiteJ {
     callback1.latch.await();
     assertEquals(responseSize, callback1.successLength);
     assertNull(callback1.failure);
+    context.close();
   }
 
   // The timeout is relative to the LAST request sent, which is kinda weird, 
but still.
@@ -265,6 +277,7 @@ public class RequestTimeoutIntegrationSuiteJ {
     callback1.latch.await(60, TimeUnit.SECONDS);
     // failed at same time as previous
     assertTrue(callback1.failure instanceof IOException);
+    context.close();
   }
 
   /**
diff --git 
a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/network/SSLRequestTimeoutIntegrationSuiteJ.java
 
b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/network/SSLRequestTimeoutIntegrationSuiteJ.java
new file mode 100644
index 000000000..4037c6f1b
--- /dev/null
+++ 
b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/network/SSLRequestTimeoutIntegrationSuiteJ.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.service.deploy.worker.network;
+
+import static org.junit.Assert.assertTrue;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.common.network.TestHelper;
+import org.apache.celeborn.common.network.ssl.SslSampleConfigs;
+
+public class SSLRequestTimeoutIntegrationSuiteJ extends 
RequestTimeoutIntegrationSuiteJ {
+  @Before
+  public void setUp() {
+    // set up SSL for TEST_MODULE
+    doSetup(
+        TestHelper.updateCelebornConfWithMap(
+            new CelebornConf(), 
SslSampleConfigs.createDefaultConfigMapForModule(TEST_MODULE)));
+  }
+
+  @After
+  public void tearDown() {
+    super.tearDown();
+  }
+
+  @Test
+  public void validateSslConfig() {
+    // this is to ensure ssl config has been applied.
+    assertTrue(super.getConf().sslEnabled());
+  }
+}
diff --git 
a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/ChunkFetchIntegrationSuiteJ.java
 
b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/ChunkFetchIntegrationSuiteJ.java
index 36e12ca4c..f4eaf618e 100644
--- 
a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/ChunkFetchIntegrationSuiteJ.java
+++ 
b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/ChunkFetchIntegrationSuiteJ.java
@@ -55,10 +55,12 @@ import 
org.apache.celeborn.common.network.util.TransportConf;
 import org.apache.celeborn.common.protocol.PbChunkFetchRequest;
 
 public class ChunkFetchIntegrationSuiteJ {
+  static final String TEST_MODULE = "shuffle";
   static final long STREAM_ID = 1;
   static final int BUFFER_CHUNK_INDEX = 0;
   static final int FILE_CHUNK_INDEX = 1;
 
+  static TransportContext transportContext;
   static TransportServer server;
   static TransportClientFactory clientFactory;
   static ChunkStreamManager chunkStreamManager;
@@ -69,6 +71,10 @@ public class ChunkFetchIntegrationSuiteJ {
 
   @BeforeClass
   public static void setUp() throws Exception {
+    initialize((new CelebornConf()));
+  }
+
+  static void initialize(CelebornConf celebornConf) throws Exception {
     int bufSize = 100_000;
     final ByteBuffer buf = ByteBuffer.allocate(bufSize);
     for (int i = 0; i < bufSize; i++) {
@@ -90,7 +96,7 @@ public class ChunkFetchIntegrationSuiteJ {
       Closeables.close(fp, shouldSuppressIOException);
     }
 
-    final TransportConf conf = new TransportConf("shuffle", new 
CelebornConf());
+    final TransportConf conf = new TransportConf(TEST_MODULE, celebornConf);
     fileChunk = new FileSegmentManagedBuffer(conf, testFile, 10, 
testFile.length() - 25);
 
     chunkStreamManager =
@@ -144,9 +150,9 @@ public class ChunkFetchIntegrationSuiteJ {
             return true;
           }
         };
-    TransportContext context = new TransportContext(conf, handler);
-    server = context.createServer();
-    clientFactory = context.createClientFactory();
+    transportContext = new TransportContext(conf, handler);
+    server = transportContext.createServer();
+    clientFactory = transportContext.createClientFactory();
   }
 
   @AfterClass
@@ -154,6 +160,7 @@ public class ChunkFetchIntegrationSuiteJ {
     bufferChunk.release();
     server.close();
     clientFactory.close();
+    transportContext.close();
     testFile.delete();
   }
 
@@ -205,6 +212,11 @@ public class ChunkFetchIntegrationSuiteJ {
     return res;
   }
 
+  // for subclasses to validate
+  TransportConf fetchTransportConf() {
+    return transportContext.getConf();
+  }
+
   @Test
   public void fetchBufferChunk() throws Exception {
     FetchResult res = fetchChunks(Arrays.asList(BUFFER_CHUNK_INDEX));
diff --git 
a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/ReducePartitionDataWriterSuiteJ.java
 
b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/ReducePartitionDataWriterSuiteJ.java
index 1271d2c45..6417e2aa0 100644
--- 
a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/ReducePartitionDataWriterSuiteJ.java
+++ 
b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/ReducePartitionDataWriterSuiteJ.java
@@ -90,14 +90,13 @@ public class ReducePartitionDataWriterSuiteJ {
   private static LocalFlusher localFlusher = null;
   private static WorkerSource source = null;
 
-  private static TransportServer server;
-  private static TransportClientFactory clientFactory;
+  private TransportContext transportContext;
+  private TransportServer server;
+  private TransportClientFactory clientFactory;
   private static long streamId;
   private static int numChunks;
   private final UserIdentifier userIdentifier = new 
UserIdentifier("mock-tenantId", "mock-name");
 
-  private static final TransportConf transConf = new TransportConf("shuffle", 
new CelebornConf());
-
   @BeforeClass
   public static void beforeAll() {
     tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), 
"celeborn");
@@ -138,7 +137,13 @@ public class ReducePartitionDataWriterSuiteJ {
     MemoryManager.initialize(conf);
   }
 
-  public static void setupChunkServer(DiskFileInfo info) throws IOException {
+  protected TransportConf createModuleTransportConf(String module) {
+    return new TransportConf(module, new CelebornConf());
+  }
+
+  public void setupChunkServer(DiskFileInfo info) throws IOException {
+    TransportConf transConf = createModuleTransportConf("shuffle");
+
     FetchHandler handler =
         new FetchHandler(transConf.getCelebornConf(), transConf, 
mock(WorkerSource.class)) {
           @Override
@@ -166,10 +171,10 @@ public class ReducePartitionDataWriterSuiteJ {
         .when(sorter)
         .getSortedFileInfo(anyString(), anyString(), eq(info), anyInt(), 
anyInt());
     handler.setPartitionsSorter(sorter);
-    TransportContext context = new TransportContext(transConf, handler);
-    server = context.createServer();
+    transportContext = new TransportContext(transConf, handler);
+    server = transportContext.createServer();
 
-    clientFactory = context.createClientFactory();
+    clientFactory = transportContext.createClientFactory();
   }
 
   @AfterClass
@@ -184,9 +189,10 @@ public class ReducePartitionDataWriterSuiteJ {
     }
   }
 
-  public static void closeChunkServer() {
+  public void closeChunkServer() {
     server.close();
     clientFactory.close();
+    transportContext.close();
   }
 
   static class FetchResult {
diff --git 
a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/SSLChunkFetchIntegrationSuiteJ.java
 
b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/SSLChunkFetchIntegrationSuiteJ.java
new file mode 100644
index 000000000..e3cc85494
--- /dev/null
+++ 
b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/SSLChunkFetchIntegrationSuiteJ.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.service.deploy.worker.storage;
+
+import static org.junit.Assert.assertTrue;
+
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.common.network.TestHelper;
+import org.apache.celeborn.common.network.ssl.SslSampleConfigs;
+
+public class SSLChunkFetchIntegrationSuiteJ extends 
ChunkFetchIntegrationSuiteJ {
+
+  @BeforeClass
+  public static void setUp() throws Exception {
+    SSLChunkFetchIntegrationSuiteJ.initialize(
+        TestHelper.updateCelebornConfWithMap(
+            new CelebornConf(), 
SslSampleConfigs.createDefaultConfigMapForModule(TEST_MODULE)));
+  }
+
+  @AfterClass
+  public static void tearDown() {
+    ChunkFetchIntegrationSuiteJ.tearDown();
+  }
+
+  @Test
+  public void validateSslConfig() {
+    // this is to ensure ssl config has been applied.
+    assertTrue(fetchTransportConf().sslEnabled());
+  }
+}
diff --git 
a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/SSLReducePartitionDataWriterSuiteJ.java
 
b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/SSLReducePartitionDataWriterSuiteJ.java
new file mode 100644
index 000000000..e6e37ed51
--- /dev/null
+++ 
b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/SSLReducePartitionDataWriterSuiteJ.java
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.service.deploy.worker.storage;
+
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+
+import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.common.network.TestHelper;
+import org.apache.celeborn.common.network.ssl.SslSampleConfigs;
+import org.apache.celeborn.common.network.util.TransportConf;
+
+public class SSLReducePartitionDataWriterSuiteJ extends 
ReducePartitionDataWriterSuiteJ {
+  protected TransportConf createModuleTransportConf(String module) {
+    CelebornConf conf =
+        TestHelper.updateCelebornConfWithMap(
+            new CelebornConf(), 
SslSampleConfigs.createDefaultConfigMapForModule(module));
+    return new TransportConf(module, conf);
+  }
+
+  @BeforeClass
+  public static void beforeAll() {
+    ReducePartitionDataWriterSuiteJ.beforeAll();
+  }
+
+  @AfterClass
+  public static void afterAll() {
+    ReducePartitionDataWriterSuiteJ.afterAll();
+  }
+}

Reply via email to