This is an automated email from the ASF dual-hosted git repository.

rexxiong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new a13d16761 [CELEBORN-1401] Add SSL support for ratis communication
a13d16761 is described below

commit a13d16761772f638245a449187b0dc01e0b9ae12
Author: Mridul Muralidharan <mridulatgmail.com>
AuthorDate: Fri May 17 17:08:11 2024 +0800

    [CELEBORN-1401] Add SSL support for ratis communication
    
    ### What changes were proposed in this pull request?
    
    When SSL is enabled for master, secure the Ratis communication as well with 
TLS
    
    ### Why are the changes needed?
    
    Currently, when TLS is enabled for RPC, Ratis communication still goes in the 
clear - this change adds TLS support for it.
    Note that currently this only supports GRPC, and not netty.
    
    ### Does this PR introduce _any_ user-facing change?
    Secures ratis communication when TLS is enabled at master for rpc.
    
    ### How was this patch tested?
    Local tests and additional unit tests added
    
    Closes #2515 from mridulm/CELEBORN-1401-add-ratis-ssl-support.
    
    Authored-by: Mridul Muralidharan <mridulatgmail.com>
    Signed-off-by: Shuang <[email protected]>
---
 .../celeborn/common/network/TransportContext.java  |  33 +----
 .../celeborn/common/network/ssl/SSLFactory.java    |  49 ++++++-
 .../common/network/ssl/SslSampleConfigs.java       | 128 +++++++++++++----
 docs/security.md                                   |   2 +
 master/pom.xml                                     |   7 +
 .../deploy/master/clustermeta/ha/HARaftServer.java |  55 ++++++-
 .../master/clustermeta/ha/MasterClusterInfo.scala  |   5 +-
 .../deploy/master/clustermeta/ha/MasterNode.scala  |  11 +-
 .../ha/RatisMasterStatusSystemSuiteJ.java          |  43 +++---
 .../ha/SSLRatisMasterStatusSystemSuiteJ.java       | 160 +++++++++++++++++++++
 project/CelebornBuild.scala                        |   1 +
 11 files changed, 405 insertions(+), 89 deletions(-)

diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/TransportContext.java 
b/common/src/main/java/org/apache/celeborn/common/network/TransportContext.java
index 625350dc6..869b3fafd 100644
--- 
a/common/src/main/java/org/apache/celeborn/common/network/TransportContext.java
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/TransportContext.java
@@ -87,7 +87,7 @@ public class TransportContext implements Closeable {
     this.conf = conf;
     this.msgHandler = msgHandler;
     this.closeIdleConnections = closeIdleConnections;
-    this.sslFactory = createSslFactory();
+    this.sslFactory = SSLFactory.createSslFactory(conf);
     this.channelsLimiter = channelsLimiter;
     this.enableHeartbeat = enableHeartbeat;
     this.source = source;
@@ -216,37 +216,6 @@ public class TransportContext implements Closeable {
     }
   }
 
-  private SSLFactory createSslFactory() {
-    if (conf.sslEnabled()) {
-
-      if (conf.sslEnabledAndKeysAreValid()) {
-        return new SSLFactory.Builder()
-            .requestedProtocol(conf.sslProtocol())
-            .requestedCiphers(conf.sslRequestedCiphers())
-            .autoSslEnabled(conf.autoSslEnabled())
-            .keyStore(conf.sslKeyStore(), conf.sslKeyStorePassword())
-            .trustStore(
-                conf.sslTrustStore(),
-                conf.sslTrustStorePassword(),
-                conf.sslTrustStoreReloadingEnabled(),
-                conf.sslTrustStoreReloadIntervalMs())
-            .build();
-      } else {
-        logger.error(
-            "SSL encryption enabled but keyStore is not configured for "
-                + conf.getModuleName()
-                + "! Please ensure the configured keys are present.");
-        throw new IllegalArgumentException(
-            conf.getModuleName()
-                + " SSL encryption enabled for "
-                + conf.getModuleName()
-                + " but keyStore not configured !");
-      }
-    } else {
-      return null;
-    }
-  }
-
   private TransportChannelHandler createChannelHandler(
       Channel channel, BaseMessageHandler msgHandler) {
     TransportResponseHandler responseHandler = new 
TransportResponseHandler(conf, channel);
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/ssl/SSLFactory.java 
b/common/src/main/java/org/apache/celeborn/common/network/ssl/SSLFactory.java
index 9612a7749..a0eb106bf 100644
--- 
a/common/src/main/java/org/apache/celeborn/common/network/ssl/SSLFactory.java
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/ssl/SSLFactory.java
@@ -30,6 +30,7 @@ import java.security.cert.CertificateException;
 import java.security.cert.X509Certificate;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collections;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Set;
@@ -47,6 +48,7 @@ import io.netty.buffer.ByteBufAllocator;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import org.apache.celeborn.common.network.util.TransportConf;
 import org.apache.celeborn.common.util.JavaUtils;
 
 /**
@@ -96,6 +98,18 @@ public class SSLFactory {
     this.jdkSslContext = createSSLContext(requestedProtocol, keyManagers, 
trustManagers);
   }
 
+  public List<KeyManager> getKeyManagers() {
+    return null != keyManagers
+        ? Collections.unmodifiableList(Arrays.asList(keyManagers))
+        : Collections.emptyList();
+  }
+
+  public List<TrustManager> getTrustManagers() {
+    return null != trustManagers
+        ? Collections.unmodifiableList(Arrays.asList(trustManagers))
+        : Collections.emptyList();
+  }
+
   /*
    * As b.trustStore is null, credulousTrustStoreManagers will be used - and 
so all
    * certs will be accepted - and hence self-signed cert from lifecycle 
manager will
@@ -119,7 +133,7 @@ public class SSLFactory {
   }
 
   public boolean hasKeyManagers() {
-    return null != keyManagers;
+    return null != keyManagers && keyManagers.length > 0;
   }
 
   public void destroy() {
@@ -327,7 +341,7 @@ public class SSLFactory {
     }
   }
 
-  private static TrustManager[] defaultTrustManagers(File trustStore, String 
trustStorePassword)
+  public static TrustManager[] defaultTrustManagers(File trustStore, String 
trustStorePassword)
       throws IOException, KeyStoreException, CertificateException, 
NoSuchAlgorithmException {
     try (InputStream input = Files.asByteSource(trustStore).openStream()) {
       KeyStore ks = KeyStore.getInstance(KeyStore.getDefaultType());
@@ -436,4 +450,35 @@ public class SSLFactory {
     }
     return enabled;
   }
+
+  public static SSLFactory createSslFactory(TransportConf conf) {
+    if (conf.sslEnabled()) {
+
+      if (conf.sslEnabledAndKeysAreValid()) {
+        return new SSLFactory.Builder()
+            .requestedProtocol(conf.sslProtocol())
+            .requestedCiphers(conf.sslRequestedCiphers())
+            .autoSslEnabled(conf.autoSslEnabled())
+            .keyStore(conf.sslKeyStore(), conf.sslKeyStorePassword())
+            .trustStore(
+                conf.sslTrustStore(),
+                conf.sslTrustStorePassword(),
+                conf.sslTrustStoreReloadingEnabled(),
+                conf.sslTrustStoreReloadIntervalMs())
+            .build();
+      } else {
+        logger.error(
+            "SSL encryption enabled but keyStore is not configured for "
+                + conf.getModuleName()
+                + "! Please ensure the configured keys are present.");
+        throw new IllegalArgumentException(
+            conf.getModuleName()
+                + " SSL encryption enabled for "
+                + conf.getModuleName()
+                + " but keyStore not configured !");
+      }
+    } else {
+      return null;
+    }
+  }
 }
diff --git 
a/common/src/test/java/org/apache/celeborn/common/network/ssl/SslSampleConfigs.java
 
b/common/src/test/java/org/apache/celeborn/common/network/ssl/SslSampleConfigs.java
index 93007a936..3b634ae17 100644
--- 
a/common/src/test/java/org/apache/celeborn/common/network/ssl/SslSampleConfigs.java
+++ 
b/common/src/test/java/org/apache/celeborn/common/network/ssl/SslSampleConfigs.java
@@ -26,19 +26,34 @@ import java.nio.file.Files;
 import java.nio.file.StandardCopyOption;
 import java.security.*;
 import java.security.cert.Certificate;
-import java.security.cert.CertificateEncodingException;
 import java.security.cert.X509Certificate;
-import java.util.Date;
-import java.util.HashMap;
-import java.util.Map;
-
-import javax.security.auth.x500.X500Principal;
+import java.util.*;
+import java.util.stream.Stream;
 
 import org.apache.commons.io.FileUtils;
-import org.bouncycastle.x509.X509V1CertificateGenerator;
+import org.bouncycastle.asn1.x500.X500Name;
+import org.bouncycastle.asn1.x509.BasicConstraints;
+import org.bouncycastle.asn1.x509.Extension;
+import org.bouncycastle.asn1.x509.GeneralName;
+import org.bouncycastle.asn1.x509.GeneralNames;
+import org.bouncycastle.cert.X509v3CertificateBuilder;
+import org.bouncycastle.cert.jcajce.JcaX509CertificateConverter;
+import org.bouncycastle.cert.jcajce.JcaX509CertificateHolder;
+import org.bouncycastle.cert.jcajce.JcaX509v3CertificateBuilder;
+import org.bouncycastle.jce.provider.BouncyCastleProvider;
+import org.bouncycastle.operator.ContentSigner;
+import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 public class SslSampleConfigs {
 
+  private static final Logger LOG = 
LoggerFactory.getLogger(SslSampleConfigs.class);
+
+  static {
+    Security.addProvider(new BouncyCastleProvider());
+  }
+
   public static final String DEFAULT_KEY_STORE_PATH = 
getResourceAsAbsolutePath("/ssl/server.jks");
   public static final String SECOND_KEY_STORE_PATH =
       getResourceAsAbsolutePath("/ssl/server_another.jks");
@@ -113,34 +128,95 @@ public class SslSampleConfigs {
    * @param algorithm the signing algorithm, eg "SHA1withRSA"
    * @return the self-signed certificate
    */
+  public static X509Certificate generateCertificate(
+      String dn, KeyPair pair, int days, String algorithm) throws Exception {
+    return generateCertificate(dn, pair, days, algorithm, false, null, null, 
null);
+  }
+
+  /**
+   * Create a self-signed X.509 Certificate.
+   *
+   * @param dn the X.509 Distinguished Name, eg "CN=Test, L=London, C=GB"
+   * @param pair the KeyPair for the server
+   * @param days how many days from now the Certificate is valid for
+   * @param algorithm the signing algorithm, eg "SHA1withRSA"
+   * @param generateCaCert Is this request to generate a CA cert
+   * @param altNames Optional: Alternate names to be added to the cert - we 
add them as both
+   *     hostnames and IP addresses.
+   * @param caKeyPair Optional: the KeyPair of the CA, to be used to sign this 
certificate. caCert
+   *     should also be specified to use it
+   * @param caCert Optional: the CA cert, to be used to sign this certificate. 
caKeyPair should also
+   *     be specified to use it
+   * @return the signed certificate (signed using ca if provided, else 
self-signed)
+   */
   @SuppressWarnings("deprecation")
   public static X509Certificate generateCertificate(
-      String dn, KeyPair pair, int days, String algorithm)
-      throws CertificateEncodingException, InvalidKeyException, 
IllegalStateException,
-          NoSuchAlgorithmException, SignatureException {
+      String dn,
+      KeyPair pair,
+      int days,
+      String algorithm,
+      boolean generateCaCert,
+      String[] altNames,
+      KeyPair caKeyPair,
+      X509Certificate caCert)
+      throws Exception {
 
     Date from = new Date();
     Date to = new Date(from.getTime() + days * 86400000L);
     BigInteger sn = new BigInteger(64, new SecureRandom());
-    KeyPair keyPair = pair;
-    X509V1CertificateGenerator certGen = new X509V1CertificateGenerator();
-    X500Principal dnName = new X500Principal(dn);
-
-    certGen.setSerialNumber(sn);
-    certGen.setIssuerDN(dnName);
-    certGen.setNotBefore(from);
-    certGen.setNotAfter(to);
-    certGen.setSubjectDN(dnName);
-    certGen.setPublicKey(keyPair.getPublic());
-    certGen.setSignatureAlgorithm(algorithm);
-
-    X509Certificate cert = certGen.generate(pair.getPrivate());
-    return cert;
+    X500Name subjectName = new X500Name(dn);
+
+    X500Name issuerName;
+    KeyPair signingKeyPair;
+
+    if (caKeyPair != null && caCert != null) {
+      issuerName = new JcaX509CertificateHolder(caCert).getSubject();
+      signingKeyPair = caKeyPair;
+    } else {
+      issuerName = subjectName;
+      // self signed
+      signingKeyPair = pair;
+    }
+
+    X509v3CertificateBuilder certBuilder =
+        new JcaX509v3CertificateBuilder(
+            issuerName, sn, from, to, new X500Name(dn), pair.getPublic());
+
+    if (null != altNames) {
+      Stream<GeneralName> dnsStream =
+          Arrays.stream(altNames).map(h -> new 
GeneralName(GeneralName.dNSName, h));
+      Stream<GeneralName> ipStream =
+          Arrays.stream(altNames)
+              .map(
+                  h -> {
+                    try {
+                      return new GeneralName(GeneralName.iPAddress, h);
+                    } catch (Exception ex) {
+                      return null;
+                    }
+                  })
+              .filter(Objects::nonNull);
+
+      GeneralName[] arr = Stream.concat(dnsStream, 
ipStream).toArray(GeneralName[]::new);
+      GeneralNames names = new GeneralNames(arr);
+
+      certBuilder.addExtension(Extension.subjectAlternativeName, false, names);
+      LOG.info("Added subjectAlternativeName extension for hosts : " + 
Arrays.toString(altNames));
+    }
+
+    if (generateCaCert) {
+      certBuilder.addExtension(Extension.basicConstraints, true, new 
BasicConstraints(true));
+      LOG.info("Added CA cert extension");
+    }
+
+    ContentSigner signer =
+        new 
JcaContentSignerBuilder(algorithm).build(signingKeyPair.getPrivate());
+    return new 
JcaX509CertificateConverter().getCertificate(certBuilder.build(signer));
   }
 
   public static KeyPair generateKeyPair(String algorithm) throws 
NoSuchAlgorithmException {
     KeyPairGenerator keyGen = KeyPairGenerator.getInstance(algorithm);
-    keyGen.initialize(1024);
+    keyGen.initialize(4096);
     return keyGen.genKeyPair();
   }
 
@@ -178,7 +254,7 @@ public class SslSampleConfigs {
   }
 
   private static KeyStore createEmptyKeyStore() throws 
GeneralSecurityException, IOException {
-    KeyStore ks = KeyStore.getInstance(KeyStore.getDefaultType());
+    KeyStore ks = KeyStore.getInstance("PKCS12");
     ks.load(null, null); // initialize
     return ks;
   }
diff --git a/docs/security.md b/docs/security.md
index 563cfaf7b..8fa7b62fd 100644
--- a/docs/security.md
+++ b/docs/security.md
@@ -34,6 +34,8 @@ start="<!--begin-include-->"
 end="<!--end-include-->"
 !}
 
+When SSL is enabled for `rpc_service`, Raft communication between masters is 
secured **only when** `celeborn.master.ha.ratis.raft.rpc.type` is set to `grpc`.
+
 Note that `celeborn.ssl`, **without any module**, can be used to set SSL 
default values which applies to all modules.
 
 Also note that `data` module at application side, maps to `push` and `fetch` 
at worker - hence, for SSL configuration, worker configuration for `push` and 
`fetch` should be compatible with each other and with `data` at application 
side.
diff --git a/master/pom.xml b/master/pom.xml
index 4dd3ed11c..0f1fe12bb 100644
--- a/master/pom.xml
+++ b/master/pom.xml
@@ -104,6 +104,13 @@
       <artifactId>jersey-test-framework-provider-jetty</artifactId>
       <scope>test</scope>
     </dependency>
+    <dependency>
+      <groupId>org.apache.celeborn</groupId>
+      <artifactId>celeborn-common_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
   </dependencies>
 
   <build>
diff --git 
a/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/HARaftServer.java
 
b/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/HARaftServer.java
index 5a5694d86..f0a7560d7 100644
--- 
a/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/HARaftServer.java
+++ 
b/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/HARaftServer.java
@@ -19,6 +19,7 @@ package 
org.apache.celeborn.service.deploy.master.clustermeta.ha;
 
 import java.io.File;
 import java.io.IOException;
+import java.net.InetAddress;
 import java.net.InetSocketAddress;
 import java.nio.charset.StandardCharsets;
 import java.util.*;
@@ -27,14 +28,19 @@ import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicLong;
 import java.util.concurrent.locks.ReentrantReadWriteLock;
 
+import javax.net.ssl.KeyManager;
+import javax.net.ssl.TrustManager;
+
 import scala.Tuple2;
 
 import com.google.common.annotations.VisibleForTesting;
 import com.google.protobuf.InvalidProtocolBufferException;
 import org.apache.ratis.RaftConfigKeys;
 import org.apache.ratis.client.RaftClientConfigKeys;
+import org.apache.ratis.conf.Parameters;
 import org.apache.ratis.conf.RaftProperties;
 import org.apache.ratis.grpc.GrpcConfigKeys;
+import org.apache.ratis.grpc.GrpcTlsConfig;
 import org.apache.ratis.netty.NettyConfigKeys;
 import org.apache.ratis.proto.RaftProtos;
 import org.apache.ratis.protocol.*;
@@ -54,6 +60,8 @@ import org.slf4j.LoggerFactory;
 import org.apache.celeborn.common.CelebornConf;
 import org.apache.celeborn.common.client.MasterClient;
 import org.apache.celeborn.common.exception.CelebornRuntimeException;
+import org.apache.celeborn.common.network.ssl.SSLFactory;
+import org.apache.celeborn.common.protocol.TransportModuleConstants;
 import org.apache.celeborn.common.util.ThreadUtils;
 import org.apache.celeborn.common.util.Utils;
 import org.apache.celeborn.service.deploy.master.clustermeta.ResourceProtos;
@@ -123,13 +131,18 @@ public class HARaftServer {
     this.raftGroup = RaftGroup.valueOf(RAFT_GROUP_ID, raftPeers);
     this.masterStateMachine = getStateMachine();
     this.conf = conf;
-    RaftProperties serverProperties = newRaftProperties(conf);
+
+    final RpcType rpc = 
SupportedRpcType.valueOfIgnoreCase(conf.haMasterRatisRpcType());
+    RaftProperties serverProperties = newRaftProperties(conf, rpc);
+    Parameters sslParameters =
+        localNode.sslEnabled() ? configureSsl(conf, serverProperties, rpc) : 
null;
     setDeadlineTime(Integer.MAX_VALUE, Integer.MAX_VALUE); // for default
     this.server =
         RaftServer.newBuilder()
             .setServerId(this.raftPeerId)
             .setGroup(this.raftGroup)
             .setProperties(serverProperties)
+            .setParameters(sslParameters)
             .setStateMachine(masterStateMachine)
             .build();
 
@@ -270,11 +283,9 @@ public class HARaftServer {
     }
   }
 
-  private RaftProperties newRaftProperties(CelebornConf conf) {
+  private RaftProperties newRaftProperties(CelebornConf conf, RpcType rpc) {
     final RaftProperties properties = new RaftProperties();
     // Set RPC type
-    final String rpcType = conf.haMasterRatisRpcType();
-    final RpcType rpc = SupportedRpcType.valueOfIgnoreCase(rpcType);
     RaftConfigKeys.Rpc.setType(properties, rpc);
 
     // Set the ratis port number
@@ -375,6 +386,37 @@ public class HARaftServer {
     return properties;
   }
 
+  private Parameters configureSsl(CelebornConf conf, RaftProperties 
properties, RpcType rpc) {
+
+    if (rpc != SupportedRpcType.GRPC) {
+      LOG.error(
+          "SSL has been disabled for Raft communication between masters. "
+              + "This is only supported when ratis is configured with GRPC");
+      return null;
+    }
+
+    // This is used only for querying state after initialization - not actual 
SSL
+    // also why nThreads does not matter
+    SSLFactory factory =
+        SSLFactory.createSslFactory(
+            Utils.fromCelebornConf(conf, 
TransportModuleConstants.RPC_SERVICE_MODULE, 1));
+
+    assert (null != factory);
+    assert (factory.hasKeyManagers());
+    assert (!factory.getTrustManagers().isEmpty());
+
+    TrustManager trustManager = factory.getTrustManagers().get(0);
+    KeyManager keyManager = factory.getKeyManagers().get(0);
+
+    Parameters params = new Parameters();
+    GrpcConfigKeys.TLS.setEnabled(properties, true);
+    GrpcConfigKeys.TLS.setConf(params, new GrpcTlsConfig(keyManager, 
trustManager, true));
+
+    LOG.info("SSL enabled for ratis communication between masters");
+
+    return params;
+  }
+
   private StateMachine getStateMachine() {
     StateMachine stateMachine = new StateMachine(this);
     stateMachine.setRaftGroupId(RAFT_GROUP_ID);
@@ -536,6 +578,11 @@ public class HARaftServer {
     return server.getGroupInfo(groupInfoRequest);
   }
 
+  // Exposed for testing
+  public InetAddress getRaftAddress() {
+    return this.ratisAddr.getAddress();
+  }
+
   public int getRaftPort() {
     return this.ratisAddr.getPort();
   }
diff --git 
a/master/src/main/scala/org/apache/celeborn/service/deploy/master/clustermeta/ha/MasterClusterInfo.scala
 
b/master/src/main/scala/org/apache/celeborn/service/deploy/master/clustermeta/ha/MasterClusterInfo.scala
index 1042f4627..83b6f9283 100644
--- 
a/master/src/main/scala/org/apache/celeborn/service/deploy/master/clustermeta/ha/MasterClusterInfo.scala
+++ 
b/master/src/main/scala/org/apache/celeborn/service/deploy/master/clustermeta/ha/MasterClusterInfo.scala
@@ -26,6 +26,7 @@ import scala.util.{Failure, Success, Try}
 import org.apache.celeborn.common.CelebornConf
 import org.apache.celeborn.common.CelebornConf._
 import org.apache.celeborn.common.internal.Logging
+import org.apache.celeborn.common.protocol.TransportModuleConstants
 
 case class MasterClusterInfo(
     localNode: MasterNode,
@@ -37,6 +38,8 @@ object MasterClusterInfo extends Logging {
   def loadHAConfig(conf: CelebornConf): MasterClusterInfo = {
     val localNodeIdOpt = conf.haMasterNodeId
     val clusterNodeIds = conf.haMasterNodeIds
+    // If ssl is enabled, we enable it for ratis as well
+    val sslEnabled = 
conf.sslEnabled(TransportModuleConstants.RPC_SERVICE_MODULE)
 
     val masterNodes = clusterNodeIds.map { nodeId =>
       val ratisHost = conf.haMasterRatisHost(nodeId)
@@ -45,7 +48,7 @@ object MasterClusterInfo extends Logging {
       val rpcPort = conf.haMasterNodePort(nodeId)
       val internalPort =
         if (conf.internalPortEnabled) conf.haMasterNodeInternalPort(nodeId) 
else rpcPort
-      MasterNode(nodeId, ratisHost, ratisPort, rpcHost, rpcPort, internalPort)
+      MasterNode(nodeId, ratisHost, ratisPort, rpcHost, rpcPort, internalPort, 
sslEnabled)
     }
 
     val (localNodes, peerNodes) = localNodeIdOpt match {
diff --git 
a/master/src/main/scala/org/apache/celeborn/service/deploy/master/clustermeta/ha/MasterNode.scala
 
b/master/src/main/scala/org/apache/celeborn/service/deploy/master/clustermeta/ha/MasterNode.scala
index 0f2b09ca9..ca4ad8da2 100644
--- 
a/master/src/main/scala/org/apache/celeborn/service/deploy/master/clustermeta/ha/MasterNode.scala
+++ 
b/master/src/main/scala/org/apache/celeborn/service/deploy/master/clustermeta/ha/MasterNode.scala
@@ -30,7 +30,8 @@ case class MasterNode(
     ratisPort: Int,
     rpcHost: String,
     rpcPort: Int,
-    internalRpcPort: Int) {
+    internalRpcPort: Int,
+    sslEnabled: Boolean) {
 
   def isRatisHostUnresolved: Boolean = ratisAddr.isUnresolved
 
@@ -60,6 +61,7 @@ object MasterNode extends Logging {
     private var rpcHost: String = _
     private var rpcPort = 0
     private var internalRpcPort = 0
+    private var sslEnabled = false
 
     def setNodeId(nodeId: String): this.type = {
       this.nodeId = nodeId
@@ -97,8 +99,13 @@ object MasterNode extends Logging {
       this
     }
 
+    def setSslEnabled(sslEnabled: Boolean): this.type = {
+      this.sslEnabled = sslEnabled
+      this
+    }
+
     def build: MasterNode =
-      MasterNode(nodeId, ratisHost, ratisPort, rpcHost, rpcPort, 
internalRpcPort)
+      MasterNode(nodeId, ratisHost, ratisPort, rpcHost, rpcPort, 
internalRpcPort, sslEnabled)
   }
 
   private def createSocketAddr(host: String, port: Int): InetSocketAddress = {
diff --git 
a/master/src/test/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/RatisMasterStatusSystemSuiteJ.java
 
b/master/src/test/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/RatisMasterStatusSystemSuiteJ.java
index 8f7307115..340fb8e27 100644
--- 
a/master/src/test/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/RatisMasterStatusSystemSuiteJ.java
+++ 
b/master/src/test/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/RatisMasterStatusSystemSuiteJ.java
@@ -65,8 +65,12 @@ public class RatisMasterStatusSystemSuiteJ {
   protected static RpcEndpointRef mockRpcEndpoint = 
Mockito.mock(RpcEndpointRef.class);
 
   @BeforeClass
-  public static void init() throws IOException, InterruptedException {
-    resetRaftServer();
+  public static void init() throws Exception {
+    resetRaftServer(
+        configureServerConf(new CelebornConf(), 1),
+        configureServerConf(new CelebornConf(), 2),
+        configureServerConf(new CelebornConf(), 3),
+        false);
   }
 
   private static void stopAllRaftServers() {
@@ -81,7 +85,17 @@ public class RatisMasterStatusSystemSuiteJ {
     }
   }
 
-  public static void resetRaftServer() throws IOException, 
InterruptedException {
+  static CelebornConf configureServerConf(CelebornConf conf, int id) throws 
IOException {
+    File tmpDir = File.createTempFile("celeborn-ratis" + id, "for-test-only");
+    tmpDir.delete();
+    tmpDir.mkdirs();
+    conf.set(CelebornConf.HA_MASTER_RATIS_STORAGE_DIR().key(), 
tmpDir.getAbsolutePath());
+    return conf;
+  }
+
+  public static void resetRaftServer(
+      CelebornConf conf1, CelebornConf conf2, CelebornConf conf3, boolean 
sslEnabled)
+      throws IOException, InterruptedException {
     Mockito.when(mockRpcEnv.setupEndpointRef(Mockito.any(), Mockito.any()))
         .thenReturn(mockRpcEndpoint);
     when(mockRpcEnv.setupEndpointRef(any(), any())).thenReturn(dummyRef);
@@ -101,24 +115,6 @@ public class RatisMasterStatusSystemSuiteJ {
         MetaHandler handler2 = new MetaHandler(STATUSSYSTEM2);
         MetaHandler handler3 = new MetaHandler(STATUSSYSTEM3);
 
-        CelebornConf conf1 = new CelebornConf();
-        File tmpDir1 = File.createTempFile("celeborn-ratis1", "for-test-only");
-        tmpDir1.delete();
-        tmpDir1.mkdirs();
-        conf1.set(CelebornConf.HA_MASTER_RATIS_STORAGE_DIR().key(), 
tmpDir1.getAbsolutePath());
-
-        CelebornConf conf2 = new CelebornConf();
-        File tmpDir2 = File.createTempFile("celeborn-ratis2", "for-test-only");
-        tmpDir2.delete();
-        tmpDir2.mkdirs();
-        conf2.set(CelebornConf.HA_MASTER_RATIS_STORAGE_DIR().key(), 
tmpDir2.getAbsolutePath());
-
-        CelebornConf conf3 = new CelebornConf();
-        File tmpDir3 = File.createTempFile("celeborn-ratis3", "for-test-only");
-        tmpDir3.delete();
-        tmpDir3.mkdirs();
-        conf3.set(CelebornConf.HA_MASTER_RATIS_STORAGE_DIR().key(), 
tmpDir3.getAbsolutePath());
-
         String id1 = UUID.randomUUID().toString();
         String id2 = UUID.randomUUID().toString();
         String id3 = UUID.randomUUID().toString();
@@ -133,6 +129,7 @@ public class RatisMasterStatusSystemSuiteJ {
                 .setRatisPort(ratisPort1)
                 .setRpcPort(ratisPort1)
                 .setInternalRpcPort(ratisPort1)
+                .setSslEnabled(sslEnabled)
                 .setNodeId(id1)
                 .build();
         MasterNode masterNode2 =
@@ -141,6 +138,7 @@ public class RatisMasterStatusSystemSuiteJ {
                 .setRatisPort(ratisPort2)
                 .setRpcPort(ratisPort2)
                 .setInternalRpcPort(ratisPort2)
+                .setSslEnabled(sslEnabled)
                 .setNodeId(id2)
                 .build();
         MasterNode masterNode3 =
@@ -149,6 +147,7 @@ public class RatisMasterStatusSystemSuiteJ {
                 .setRatisPort(ratisPort3)
                 .setRpcPort(ratisPort3)
                 .setInternalRpcPort(ratisPort3)
+                .setSslEnabled(sslEnabled)
                 .setNodeId(id3)
                 .build();
 
@@ -304,7 +303,7 @@ public class RatisMasterStatusSystemSuiteJ {
     } catch (CelebornRuntimeException e) {
       Assert.assertTrue(true);
     } finally {
-      resetRaftServer();
+      init();
     }
   }
 
diff --git 
a/master/src/test/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/SSLRatisMasterStatusSystemSuiteJ.java
 
b/master/src/test/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/SSLRatisMasterStatusSystemSuiteJ.java
new file mode 100644
index 000000000..f3d7bca0a
--- /dev/null
+++ 
b/master/src/test/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/SSLRatisMasterStatusSystemSuiteJ.java
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.service.deploy.master.clustermeta.ha;
+
+import static org.junit.Assert.assertTrue;
+
+import java.io.File;
+import java.net.InetAddress;
+import java.net.InetSocketAddress;
+import java.net.Socket;
+import java.security.KeyPair;
+import java.security.cert.X509Certificate;
+import java.util.concurrent.atomic.AtomicReference;
+
+import javax.net.ssl.SSLContext;
+import javax.net.ssl.SSLSocket;
+import javax.net.ssl.SSLSocketFactory;
+import javax.net.ssl.TrustManager;
+
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.common.CelebornConf$;
+import org.apache.celeborn.common.network.ssl.SSLFactory;
+import org.apache.celeborn.common.network.ssl.SslSampleConfigs;
+import org.apache.celeborn.common.protocol.TransportModuleConstants;
+import org.apache.celeborn.common.util.Utils;
+
+public class SSLRatisMasterStatusSystemSuiteJ extends 
RatisMasterStatusSystemSuiteJ {
+
+  private static final CelebornConf confWithHostPreferred = new CelebornConf();
+
+  static {
+    confWithHostPreferred.set(CelebornConf$.MODULE$.NETWORK_BIND_PREFER_IP(), 
false);
+  }
+
+  private static class CertificateData {
+    final File file;
+    final KeyPair keyPair;
+    final X509Certificate cert;
+
+    // If caData is null, we are generating for CA - else for a cert which is 
using the ca
+    // from caData
+    CertificateData(CertificateData caData) throws Exception {
+      this.file = File.createTempFile("file", ".jks");
+      file.deleteOnExit();
+
+      this.keyPair = SslSampleConfigs.generateKeyPair("RSA");
+
+      // for both ca and cert, we are simply using the same machien as CN
+      String hostname = Utils.localHostName(confWithHostPreferred);
+      final String dn = "CN=" + hostname + ",O=MyCompany,C=US";
+
+      if (null != caData) {
+        this.cert =
+            SslSampleConfigs.generateCertificate(
+                dn,
+                keyPair,
+                365,
+                "SHA256withRSA",
+                false,
+                new String[] {hostname},
+                caData.keyPair,
+                caData.cert);
+        SslSampleConfigs.createKeyStore(
+            file, "password", "password", "cert", keyPair.getPrivate(), cert);
+      } else {
+        this.cert =
+            SslSampleConfigs.generateCertificate(
+                dn, keyPair, 365, "SHA256withRSA", true, null, null, null);
+        SslSampleConfigs.createTrustStore(file, "password", "ca", cert);
+      }
+    }
+  }
+
+  private static final AtomicReference<CertificateData> caData;
+
+  static {
+    try {
+      caData = new AtomicReference<>(new CertificateData(null));
+    } catch (Exception ex) {
+      throw new IllegalStateException("Unable to initialize", ex);
+    }
+  }
+
+  @BeforeClass
+  public static void init() throws Exception {
+
+    resetRaftServer(
+        configureSsl(caData.get(), configureServerConf(new CelebornConf(), 1)),
+        configureSsl(caData.get(), configureServerConf(new CelebornConf(), 2)),
+        configureSsl(caData.get(), configureServerConf(new CelebornConf(), 3)),
+        true);
+  }
+
+  static CelebornConf configureSsl(CertificateData ca, CelebornConf conf) 
throws Exception {
+    conf.set("celeborn.master.ha.ratis.raft.rpc.type", "GRPC");
+
+    CertificateData server = new CertificateData(ca);
+
+    final String module = TransportModuleConstants.RPC_SERVICE_MODULE;
+
+    conf.set("celeborn.ssl." + module + ".enabled", "true");
+    conf.set("celeborn.ssl." + module + ".keyStore", 
server.file.getAbsolutePath());
+
+    conf.set("celeborn.ssl." + module + ".keyStorePassword", "password");
+    conf.set("celeborn.ssl." + module + ".keyPassword", "password");
+    conf.set("celeborn.ssl." + module + ".privateKeyPassword", "password");
+    conf.set("celeborn.ssl." + module + ".protocol", "TLSv1.2");
+    conf.set("celeborn.ssl." + module + ".trustStore", 
ca.file.getAbsolutePath());
+    conf.set("celeborn.ssl." + module + ".trustStorePassword", "password");
+
+    return conf;
+  }
+
+  @Test
+  public void testSslEnabled() throws Exception {
+    assertTrue(isSslServer(RATISSERVER1.getRaftAddress(), 
RATISSERVER1.getRaftPort()));
+    assertTrue(isSslServer(RATISSERVER2.getRaftAddress(), 
RATISSERVER2.getRaftPort()));
+    assertTrue(isSslServer(RATISSERVER3.getRaftAddress(), 
RATISSERVER3.getRaftPort()));
+  }
+
+  // Validate if the server listening at the port is using TLS or not.
+  static boolean isSslServer(InetAddress address, int port) throws Exception {
+    try (SSLSocket socket = createSslSocket(address, port)) {
+      socket.setSoTimeout(5000);
+      socket.startHandshake();
+      // handshake succeeded, this will always return true in this case
+      return socket.getSession().isValid();
+    }
+  }
+
+  private static SSLSocket createSslSocket(InetAddress address, int port) 
throws Exception {
+    TrustManager trustStore = 
SSLFactory.defaultTrustManagers(caData.get().file, "password")[0];
+    SSLContext context = SSLContext.getInstance("TLS");
+    context.init(null, new TrustManager[] {trustStore}, null);
+    SSLSocketFactory factory = context.getSocketFactory();
+    Socket socket = new Socket();
+    socket.connect(new InetSocketAddress(address, port), 5000);
+    socket.setSoTimeout(5000);
+
+    return (SSLSocket) factory.createSocket(socket, address.getHostAddress(), 
port, true);
+  }
+}
diff --git a/project/CelebornBuild.scala b/project/CelebornBuild.scala
index 4c889c635..d75c6ba46 100644
--- a/project/CelebornBuild.scala
+++ b/project/CelebornBuild.scala
@@ -528,6 +528,7 @@ object CelebornService {
 object CelebornMaster {
   lazy val master = Project("celeborn-master", file("master"))
     .dependsOn(CelebornCommon.common)
+    .dependsOn(CelebornCommon.common % "test->test;compile->compile")
     .dependsOn(CelebornService.service % "test->test;compile->compile")
     .settings (
       commonSettings,


Reply via email to