This is an automated email from the ASF dual-hosted git repository.
rexxiong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git
The following commit(s) were added to refs/heads/main by this push:
new a13d16761 [CELEBORN-1401] Add SSL support for ratis communication
a13d16761 is described below
commit a13d16761772f638245a449187b0dc01e0b9ae12
Author: Mridul Muralidharan <mridulatgmail.com>
AuthorDate: Fri May 17 17:08:11 2024 +0800
[CELEBORN-1401] Add SSL support for ratis communication
### What changes were proposed in this pull request?
When SSL is enabled for the master, secure the Ratis communication between
masters with TLS as well.
### Why are the changes needed?
Currently, when TLS is enabled for RPC, Ratis communication still goes in the
clear; this change adds TLS support for it.
Note that this currently supports only the gRPC transport for Ratis, not Netty.
### Does this PR introduce _any_ user-facing change?
Yes. Ratis communication between masters is secured when TLS is enabled for RPC at the master and `celeborn.master.ha.ratis.raft.rpc.type` is set to `grpc`.
### How was this patch tested?
Tested locally; additional unit tests were added (`SSLRatisMasterStatusSystemSuiteJ`).
Closes #2515 from mridulm/CELEBORN-1401-add-ratis-ssl-support.
Authored-by: Mridul Muralidharan <mridulatgmail.com>
Signed-off-by: Shuang <[email protected]>
---
.../celeborn/common/network/TransportContext.java | 33 +----
.../celeborn/common/network/ssl/SSLFactory.java | 49 ++++++-
.../common/network/ssl/SslSampleConfigs.java | 128 +++++++++++++----
docs/security.md | 2 +
master/pom.xml | 7 +
.../deploy/master/clustermeta/ha/HARaftServer.java | 55 ++++++-
.../master/clustermeta/ha/MasterClusterInfo.scala | 5 +-
.../deploy/master/clustermeta/ha/MasterNode.scala | 11 +-
.../ha/RatisMasterStatusSystemSuiteJ.java | 43 +++---
.../ha/SSLRatisMasterStatusSystemSuiteJ.java | 160 +++++++++++++++++++++
project/CelebornBuild.scala | 1 +
11 files changed, 405 insertions(+), 89 deletions(-)
diff --git
a/common/src/main/java/org/apache/celeborn/common/network/TransportContext.java
b/common/src/main/java/org/apache/celeborn/common/network/TransportContext.java
index 625350dc6..869b3fafd 100644
---
a/common/src/main/java/org/apache/celeborn/common/network/TransportContext.java
+++
b/common/src/main/java/org/apache/celeborn/common/network/TransportContext.java
@@ -87,7 +87,7 @@ public class TransportContext implements Closeable {
this.conf = conf;
this.msgHandler = msgHandler;
this.closeIdleConnections = closeIdleConnections;
- this.sslFactory = createSslFactory();
+ this.sslFactory = SSLFactory.createSslFactory(conf);
this.channelsLimiter = channelsLimiter;
this.enableHeartbeat = enableHeartbeat;
this.source = source;
@@ -216,37 +216,6 @@ public class TransportContext implements Closeable {
}
}
- private SSLFactory createSslFactory() {
- if (conf.sslEnabled()) {
-
- if (conf.sslEnabledAndKeysAreValid()) {
- return new SSLFactory.Builder()
- .requestedProtocol(conf.sslProtocol())
- .requestedCiphers(conf.sslRequestedCiphers())
- .autoSslEnabled(conf.autoSslEnabled())
- .keyStore(conf.sslKeyStore(), conf.sslKeyStorePassword())
- .trustStore(
- conf.sslTrustStore(),
- conf.sslTrustStorePassword(),
- conf.sslTrustStoreReloadingEnabled(),
- conf.sslTrustStoreReloadIntervalMs())
- .build();
- } else {
- logger.error(
- "SSL encryption enabled but keyStore is not configured for "
- + conf.getModuleName()
- + "! Please ensure the configured keys are present.");
- throw new IllegalArgumentException(
- conf.getModuleName()
- + " SSL encryption enabled for "
- + conf.getModuleName()
- + " but keyStore not configured !");
- }
- } else {
- return null;
- }
- }
-
private TransportChannelHandler createChannelHandler(
Channel channel, BaseMessageHandler msgHandler) {
TransportResponseHandler responseHandler = new
TransportResponseHandler(conf, channel);
diff --git
a/common/src/main/java/org/apache/celeborn/common/network/ssl/SSLFactory.java
b/common/src/main/java/org/apache/celeborn/common/network/ssl/SSLFactory.java
index 9612a7749..a0eb106bf 100644
---
a/common/src/main/java/org/apache/celeborn/common/network/ssl/SSLFactory.java
+++
b/common/src/main/java/org/apache/celeborn/common/network/ssl/SSLFactory.java
@@ -30,6 +30,7 @@ import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.Arrays;
+import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
@@ -47,6 +48,7 @@ import io.netty.buffer.ByteBufAllocator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import org.apache.celeborn.common.network.util.TransportConf;
import org.apache.celeborn.common.util.JavaUtils;
/**
@@ -96,6 +98,18 @@ public class SSLFactory {
this.jdkSslContext = createSSLContext(requestedProtocol, keyManagers,
trustManagers);
}
+ public List<KeyManager> getKeyManagers() {
+ return null != keyManagers
+ ? Collections.unmodifiableList(Arrays.asList(keyManagers))
+ : Collections.emptyList();
+ }
+
+ public List<TrustManager> getTrustManagers() {
+ return null != trustManagers
+ ? Collections.unmodifiableList(Arrays.asList(trustManagers))
+ : Collections.emptyList();
+ }
+
/*
* As b.trustStore is null, credulousTrustStoreManagers will be used - and
so all
* certs will be accepted - and hence self-signed cert from lifecycle
manager will
@@ -119,7 +133,7 @@ public class SSLFactory {
}
public boolean hasKeyManagers() {
- return null != keyManagers;
+ return null != keyManagers && keyManagers.length > 0;
}
public void destroy() {
@@ -327,7 +341,7 @@ public class SSLFactory {
}
}
- private static TrustManager[] defaultTrustManagers(File trustStore, String
trustStorePassword)
+ public static TrustManager[] defaultTrustManagers(File trustStore, String
trustStorePassword)
throws IOException, KeyStoreException, CertificateException,
NoSuchAlgorithmException {
try (InputStream input = Files.asByteSource(trustStore).openStream()) {
KeyStore ks = KeyStore.getInstance(KeyStore.getDefaultType());
@@ -436,4 +450,35 @@ public class SSLFactory {
}
return enabled;
}
+
+ public static SSLFactory createSslFactory(TransportConf conf) {
+ if (conf.sslEnabled()) {
+
+ if (conf.sslEnabledAndKeysAreValid()) {
+ return new SSLFactory.Builder()
+ .requestedProtocol(conf.sslProtocol())
+ .requestedCiphers(conf.sslRequestedCiphers())
+ .autoSslEnabled(conf.autoSslEnabled())
+ .keyStore(conf.sslKeyStore(), conf.sslKeyStorePassword())
+ .trustStore(
+ conf.sslTrustStore(),
+ conf.sslTrustStorePassword(),
+ conf.sslTrustStoreReloadingEnabled(),
+ conf.sslTrustStoreReloadIntervalMs())
+ .build();
+ } else {
+ logger.error(
+ "SSL encryption enabled but keyStore is not configured for "
+ + conf.getModuleName()
+ + "! Please ensure the configured keys are present.");
+ throw new IllegalArgumentException(
+ conf.getModuleName()
+ + " SSL encryption enabled for "
+ + conf.getModuleName()
+ + " but keyStore not configured !");
+ }
+ } else {
+ return null;
+ }
+ }
}
diff --git
a/common/src/test/java/org/apache/celeborn/common/network/ssl/SslSampleConfigs.java
b/common/src/test/java/org/apache/celeborn/common/network/ssl/SslSampleConfigs.java
index 93007a936..3b634ae17 100644
---
a/common/src/test/java/org/apache/celeborn/common/network/ssl/SslSampleConfigs.java
+++
b/common/src/test/java/org/apache/celeborn/common/network/ssl/SslSampleConfigs.java
@@ -26,19 +26,34 @@ import java.nio.file.Files;
import java.nio.file.StandardCopyOption;
import java.security.*;
import java.security.cert.Certificate;
-import java.security.cert.CertificateEncodingException;
import java.security.cert.X509Certificate;
-import java.util.Date;
-import java.util.HashMap;
-import java.util.Map;
-
-import javax.security.auth.x500.X500Principal;
+import java.util.*;
+import java.util.stream.Stream;
import org.apache.commons.io.FileUtils;
-import org.bouncycastle.x509.X509V1CertificateGenerator;
+import org.bouncycastle.asn1.x500.X500Name;
+import org.bouncycastle.asn1.x509.BasicConstraints;
+import org.bouncycastle.asn1.x509.Extension;
+import org.bouncycastle.asn1.x509.GeneralName;
+import org.bouncycastle.asn1.x509.GeneralNames;
+import org.bouncycastle.cert.X509v3CertificateBuilder;
+import org.bouncycastle.cert.jcajce.JcaX509CertificateConverter;
+import org.bouncycastle.cert.jcajce.JcaX509CertificateHolder;
+import org.bouncycastle.cert.jcajce.JcaX509v3CertificateBuilder;
+import org.bouncycastle.jce.provider.BouncyCastleProvider;
+import org.bouncycastle.operator.ContentSigner;
+import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
public class SslSampleConfigs {
+ private static final Logger LOG =
LoggerFactory.getLogger(SslSampleConfigs.class);
+
+ static {
+ Security.addProvider(new BouncyCastleProvider());
+ }
+
public static final String DEFAULT_KEY_STORE_PATH =
getResourceAsAbsolutePath("/ssl/server.jks");
public static final String SECOND_KEY_STORE_PATH =
getResourceAsAbsolutePath("/ssl/server_another.jks");
@@ -113,34 +128,95 @@ public class SslSampleConfigs {
* @param algorithm the signing algorithm, eg "SHA1withRSA"
* @return the self-signed certificate
*/
+ public static X509Certificate generateCertificate(
+ String dn, KeyPair pair, int days, String algorithm) throws Exception {
+ return generateCertificate(dn, pair, days, algorithm, false, null, null,
null);
+ }
+
+ /**
+ * Create a self-signed X.509 Certificate.
+ *
+ * @param dn the X.509 Distinguished Name, eg "CN=Test, L=London, C=GB"
+ * @param pair the KeyPair for the server
+ * @param days how many days from now the Certificate is valid for
+ * @param algorithm the signing algorithm, eg "SHA1withRSA"
+ * @param generateCaCert Is this request to generate a CA cert
+ * @param altNames Optional: Alternate names to be added to the cert - we
add them as both
+ * hostnames and ip's.
+ * @param caKeyPair Optional: the KeyPair of the CA, to be used to sign this
certificate. caCert
+ * should also be specified to use it
+ * @param caCert Optional: the CA cert, to be used to sign this certificate.
caKeyPair should also
+ * be specified to use it
+ * @return the signed certificate (signed using ca if provided, else
self-signed)
+ */
@SuppressWarnings("deprecation")
public static X509Certificate generateCertificate(
- String dn, KeyPair pair, int days, String algorithm)
- throws CertificateEncodingException, InvalidKeyException,
IllegalStateException,
- NoSuchAlgorithmException, SignatureException {
+ String dn,
+ KeyPair pair,
+ int days,
+ String algorithm,
+ boolean generateCaCert,
+ String[] altNames,
+ KeyPair caKeyPair,
+ X509Certificate caCert)
+ throws Exception {
Date from = new Date();
Date to = new Date(from.getTime() + days * 86400000L);
BigInteger sn = new BigInteger(64, new SecureRandom());
- KeyPair keyPair = pair;
- X509V1CertificateGenerator certGen = new X509V1CertificateGenerator();
- X500Principal dnName = new X500Principal(dn);
-
- certGen.setSerialNumber(sn);
- certGen.setIssuerDN(dnName);
- certGen.setNotBefore(from);
- certGen.setNotAfter(to);
- certGen.setSubjectDN(dnName);
- certGen.setPublicKey(keyPair.getPublic());
- certGen.setSignatureAlgorithm(algorithm);
-
- X509Certificate cert = certGen.generate(pair.getPrivate());
- return cert;
+ X500Name subjectName = new X500Name(dn);
+
+ X500Name issuerName;
+ KeyPair signingKeyPair;
+
+ if (caKeyPair != null && caCert != null) {
+ issuerName = new JcaX509CertificateHolder(caCert).getSubject();
+ signingKeyPair = caKeyPair;
+ } else {
+ issuerName = subjectName;
+ // self signed
+ signingKeyPair = pair;
+ }
+
+ X509v3CertificateBuilder certBuilder =
+ new JcaX509v3CertificateBuilder(
+ issuerName, sn, from, to, new X500Name(dn), pair.getPublic());
+
+ if (null != altNames) {
+ Stream<GeneralName> dnsStream =
+ Arrays.stream(altNames).map(h -> new
GeneralName(GeneralName.dNSName, h));
+ Stream<GeneralName> ipStream =
+ Arrays.stream(altNames)
+ .map(
+ h -> {
+ try {
+ return new GeneralName(GeneralName.iPAddress, h);
+ } catch (Exception ex) {
+ return null;
+ }
+ })
+ .filter(Objects::nonNull);
+
+ GeneralName[] arr = Stream.concat(dnsStream,
ipStream).toArray(GeneralName[]::new);
+ GeneralNames names = new GeneralNames(arr);
+
+ certBuilder.addExtension(Extension.subjectAlternativeName, false, names);
+ LOG.info("Added subjectAlternativeName extension for hosts : " +
Arrays.toString(altNames));
+ }
+
+ if (generateCaCert) {
+ certBuilder.addExtension(Extension.basicConstraints, true, new
BasicConstraints(true));
+ LOG.info("Added CA cert extension");
+ }
+
+ ContentSigner signer =
+ new
JcaContentSignerBuilder(algorithm).build(signingKeyPair.getPrivate());
+ return new
JcaX509CertificateConverter().getCertificate(certBuilder.build(signer));
}
public static KeyPair generateKeyPair(String algorithm) throws
NoSuchAlgorithmException {
KeyPairGenerator keyGen = KeyPairGenerator.getInstance(algorithm);
- keyGen.initialize(1024);
+ keyGen.initialize(4096);
return keyGen.genKeyPair();
}
@@ -178,7 +254,7 @@ public class SslSampleConfigs {
}
private static KeyStore createEmptyKeyStore() throws
GeneralSecurityException, IOException {
- KeyStore ks = KeyStore.getInstance(KeyStore.getDefaultType());
+ KeyStore ks = KeyStore.getInstance("PKCS12");
ks.load(null, null); // initialize
return ks;
}
diff --git a/docs/security.md b/docs/security.md
index 563cfaf7b..8fa7b62fd 100644
--- a/docs/security.md
+++ b/docs/security.md
@@ -34,6 +34,8 @@ start="<!--begin-include-->"
end="<!--end-include-->"
!}
+When SSL is enabled for `rpc_service`, Raft communication between masters are
secured **only when** `celeborn.master.ha.ratis.raft.rpc.type` is set to `grpc`.
+
Note that `celeborn.ssl`, **without any module**, can be used to set SSL
default values which applies to all modules.
Also note that `data` module at application side, maps to `push` and `fetch`
at worker - hence, for SSL configuration, worker configuration for `push` and
`fetch` should be compatible with each other and with `data` at application
side.
diff --git a/master/pom.xml b/master/pom.xml
index 4dd3ed11c..0f1fe12bb 100644
--- a/master/pom.xml
+++ b/master/pom.xml
@@ -104,6 +104,13 @@
<artifactId>jersey-test-framework-provider-jetty</artifactId>
<scope>test</scope>
</dependency>
+ <dependency>
+ <groupId>org.apache.celeborn</groupId>
+ <artifactId>celeborn-common_${scala.binary.version}</artifactId>
+ <version>${project.version}</version>
+ <type>test-jar</type>
+ <scope>test</scope>
+ </dependency>
</dependencies>
<build>
diff --git
a/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/HARaftServer.java
b/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/HARaftServer.java
index 5a5694d86..f0a7560d7 100644
---
a/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/HARaftServer.java
+++
b/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/HARaftServer.java
@@ -19,6 +19,7 @@ package
org.apache.celeborn.service.deploy.master.clustermeta.ha;
import java.io.File;
import java.io.IOException;
+import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.nio.charset.StandardCharsets;
import java.util.*;
@@ -27,14 +28,19 @@ import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.ReentrantReadWriteLock;
+import javax.net.ssl.KeyManager;
+import javax.net.ssl.TrustManager;
+
import scala.Tuple2;
import com.google.common.annotations.VisibleForTesting;
import com.google.protobuf.InvalidProtocolBufferException;
import org.apache.ratis.RaftConfigKeys;
import org.apache.ratis.client.RaftClientConfigKeys;
+import org.apache.ratis.conf.Parameters;
import org.apache.ratis.conf.RaftProperties;
import org.apache.ratis.grpc.GrpcConfigKeys;
+import org.apache.ratis.grpc.GrpcTlsConfig;
import org.apache.ratis.netty.NettyConfigKeys;
import org.apache.ratis.proto.RaftProtos;
import org.apache.ratis.protocol.*;
@@ -54,6 +60,8 @@ import org.slf4j.LoggerFactory;
import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.client.MasterClient;
import org.apache.celeborn.common.exception.CelebornRuntimeException;
+import org.apache.celeborn.common.network.ssl.SSLFactory;
+import org.apache.celeborn.common.protocol.TransportModuleConstants;
import org.apache.celeborn.common.util.ThreadUtils;
import org.apache.celeborn.common.util.Utils;
import org.apache.celeborn.service.deploy.master.clustermeta.ResourceProtos;
@@ -123,13 +131,18 @@ public class HARaftServer {
this.raftGroup = RaftGroup.valueOf(RAFT_GROUP_ID, raftPeers);
this.masterStateMachine = getStateMachine();
this.conf = conf;
- RaftProperties serverProperties = newRaftProperties(conf);
+
+ final RpcType rpc =
SupportedRpcType.valueOfIgnoreCase(conf.haMasterRatisRpcType());
+ RaftProperties serverProperties = newRaftProperties(conf, rpc);
+ Parameters sslParameters =
+ localNode.sslEnabled() ? configureSsl(conf, serverProperties, rpc) :
null;
setDeadlineTime(Integer.MAX_VALUE, Integer.MAX_VALUE); // for default
this.server =
RaftServer.newBuilder()
.setServerId(this.raftPeerId)
.setGroup(this.raftGroup)
.setProperties(serverProperties)
+ .setParameters(sslParameters)
.setStateMachine(masterStateMachine)
.build();
@@ -270,11 +283,9 @@ public class HARaftServer {
}
}
- private RaftProperties newRaftProperties(CelebornConf conf) {
+ private RaftProperties newRaftProperties(CelebornConf conf, RpcType rpc) {
final RaftProperties properties = new RaftProperties();
// Set RPC type
- final String rpcType = conf.haMasterRatisRpcType();
- final RpcType rpc = SupportedRpcType.valueOfIgnoreCase(rpcType);
RaftConfigKeys.Rpc.setType(properties, rpc);
// Set the ratis port number
@@ -375,6 +386,37 @@ public class HARaftServer {
return properties;
}
+ private Parameters configureSsl(CelebornConf conf, RaftProperties
properties, RpcType rpc) {
+
+ if (rpc != SupportedRpcType.GRPC) {
+ LOG.error(
+ "SSL has been disabled for Raft communication between masters. "
+ + "This is only supported when ratis is configured with GRPC");
+ return null;
+ }
+
+ // This is used only for querying state after initialization - not actual
SSL
+ // also why nThreads does not matter
+ SSLFactory factory =
+ SSLFactory.createSslFactory(
+ Utils.fromCelebornConf(conf,
TransportModuleConstants.RPC_SERVICE_MODULE, 1));
+
+ assert (null != factory);
+ assert (factory.hasKeyManagers());
+ assert (!factory.getTrustManagers().isEmpty());
+
+ TrustManager trustManager = factory.getTrustManagers().get(0);
+ KeyManager keyManager = factory.getKeyManagers().get(0);
+
+ Parameters params = new Parameters();
+ GrpcConfigKeys.TLS.setEnabled(properties, true);
+ GrpcConfigKeys.TLS.setConf(params, new GrpcTlsConfig(keyManager,
trustManager, true));
+
+ LOG.info("SSL enabled for ratis communication between masters");
+
+ return params;
+ }
+
private StateMachine getStateMachine() {
StateMachine stateMachine = new StateMachine(this);
stateMachine.setRaftGroupId(RAFT_GROUP_ID);
@@ -536,6 +578,11 @@ public class HARaftServer {
return server.getGroupInfo(groupInfoRequest);
}
+ // Exposed for testing
+ public InetAddress getRaftAddress() {
+ return this.ratisAddr.getAddress();
+ }
+
public int getRaftPort() {
return this.ratisAddr.getPort();
}
diff --git
a/master/src/main/scala/org/apache/celeborn/service/deploy/master/clustermeta/ha/MasterClusterInfo.scala
b/master/src/main/scala/org/apache/celeborn/service/deploy/master/clustermeta/ha/MasterClusterInfo.scala
index 1042f4627..83b6f9283 100644
---
a/master/src/main/scala/org/apache/celeborn/service/deploy/master/clustermeta/ha/MasterClusterInfo.scala
+++
b/master/src/main/scala/org/apache/celeborn/service/deploy/master/clustermeta/ha/MasterClusterInfo.scala
@@ -26,6 +26,7 @@ import scala.util.{Failure, Success, Try}
import org.apache.celeborn.common.CelebornConf
import org.apache.celeborn.common.CelebornConf._
import org.apache.celeborn.common.internal.Logging
+import org.apache.celeborn.common.protocol.TransportModuleConstants
case class MasterClusterInfo(
localNode: MasterNode,
@@ -37,6 +38,8 @@ object MasterClusterInfo extends Logging {
def loadHAConfig(conf: CelebornConf): MasterClusterInfo = {
val localNodeIdOpt = conf.haMasterNodeId
val clusterNodeIds = conf.haMasterNodeIds
+ // If ssl is enabled, we enable it for ratis as well
+ val sslEnabled =
conf.sslEnabled(TransportModuleConstants.RPC_SERVICE_MODULE)
val masterNodes = clusterNodeIds.map { nodeId =>
val ratisHost = conf.haMasterRatisHost(nodeId)
@@ -45,7 +48,7 @@ object MasterClusterInfo extends Logging {
val rpcPort = conf.haMasterNodePort(nodeId)
val internalPort =
if (conf.internalPortEnabled) conf.haMasterNodeInternalPort(nodeId)
else rpcPort
- MasterNode(nodeId, ratisHost, ratisPort, rpcHost, rpcPort, internalPort)
+ MasterNode(nodeId, ratisHost, ratisPort, rpcHost, rpcPort, internalPort,
sslEnabled)
}
val (localNodes, peerNodes) = localNodeIdOpt match {
diff --git
a/master/src/main/scala/org/apache/celeborn/service/deploy/master/clustermeta/ha/MasterNode.scala
b/master/src/main/scala/org/apache/celeborn/service/deploy/master/clustermeta/ha/MasterNode.scala
index 0f2b09ca9..ca4ad8da2 100644
---
a/master/src/main/scala/org/apache/celeborn/service/deploy/master/clustermeta/ha/MasterNode.scala
+++
b/master/src/main/scala/org/apache/celeborn/service/deploy/master/clustermeta/ha/MasterNode.scala
@@ -30,7 +30,8 @@ case class MasterNode(
ratisPort: Int,
rpcHost: String,
rpcPort: Int,
- internalRpcPort: Int) {
+ internalRpcPort: Int,
+ sslEnabled: Boolean) {
def isRatisHostUnresolved: Boolean = ratisAddr.isUnresolved
@@ -60,6 +61,7 @@ object MasterNode extends Logging {
private var rpcHost: String = _
private var rpcPort = 0
private var internalRpcPort = 0
+ private var sslEnabled = false
def setNodeId(nodeId: String): this.type = {
this.nodeId = nodeId
@@ -97,8 +99,13 @@ object MasterNode extends Logging {
this
}
+ def setSslEnabled(sslEnabled: Boolean): this.type = {
+ this.sslEnabled = sslEnabled
+ this
+ }
+
def build: MasterNode =
- MasterNode(nodeId, ratisHost, ratisPort, rpcHost, rpcPort,
internalRpcPort)
+ MasterNode(nodeId, ratisHost, ratisPort, rpcHost, rpcPort,
internalRpcPort, sslEnabled)
}
private def createSocketAddr(host: String, port: Int): InetSocketAddress = {
diff --git
a/master/src/test/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/RatisMasterStatusSystemSuiteJ.java
b/master/src/test/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/RatisMasterStatusSystemSuiteJ.java
index 8f7307115..340fb8e27 100644
---
a/master/src/test/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/RatisMasterStatusSystemSuiteJ.java
+++
b/master/src/test/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/RatisMasterStatusSystemSuiteJ.java
@@ -65,8 +65,12 @@ public class RatisMasterStatusSystemSuiteJ {
protected static RpcEndpointRef mockRpcEndpoint =
Mockito.mock(RpcEndpointRef.class);
@BeforeClass
- public static void init() throws IOException, InterruptedException {
- resetRaftServer();
+ public static void init() throws Exception {
+ resetRaftServer(
+ configureServerConf(new CelebornConf(), 1),
+ configureServerConf(new CelebornConf(), 2),
+ configureServerConf(new CelebornConf(), 3),
+ false);
}
private static void stopAllRaftServers() {
@@ -81,7 +85,17 @@ public class RatisMasterStatusSystemSuiteJ {
}
}
- public static void resetRaftServer() throws IOException,
InterruptedException {
+ static CelebornConf configureServerConf(CelebornConf conf, int id) throws
IOException {
+ File tmpDir = File.createTempFile("celeborn-ratis" + id, "for-test-only");
+ tmpDir.delete();
+ tmpDir.mkdirs();
+ conf.set(CelebornConf.HA_MASTER_RATIS_STORAGE_DIR().key(),
tmpDir.getAbsolutePath());
+ return conf;
+ }
+
+ public static void resetRaftServer(
+ CelebornConf conf1, CelebornConf conf2, CelebornConf conf3, boolean
sslEnabled)
+ throws IOException, InterruptedException {
Mockito.when(mockRpcEnv.setupEndpointRef(Mockito.any(), Mockito.any()))
.thenReturn(mockRpcEndpoint);
when(mockRpcEnv.setupEndpointRef(any(), any())).thenReturn(dummyRef);
@@ -101,24 +115,6 @@ public class RatisMasterStatusSystemSuiteJ {
MetaHandler handler2 = new MetaHandler(STATUSSYSTEM2);
MetaHandler handler3 = new MetaHandler(STATUSSYSTEM3);
- CelebornConf conf1 = new CelebornConf();
- File tmpDir1 = File.createTempFile("celeborn-ratis1", "for-test-only");
- tmpDir1.delete();
- tmpDir1.mkdirs();
- conf1.set(CelebornConf.HA_MASTER_RATIS_STORAGE_DIR().key(),
tmpDir1.getAbsolutePath());
-
- CelebornConf conf2 = new CelebornConf();
- File tmpDir2 = File.createTempFile("celeborn-ratis2", "for-test-only");
- tmpDir2.delete();
- tmpDir2.mkdirs();
- conf2.set(CelebornConf.HA_MASTER_RATIS_STORAGE_DIR().key(),
tmpDir2.getAbsolutePath());
-
- CelebornConf conf3 = new CelebornConf();
- File tmpDir3 = File.createTempFile("celeborn-ratis3", "for-test-only");
- tmpDir3.delete();
- tmpDir3.mkdirs();
- conf3.set(CelebornConf.HA_MASTER_RATIS_STORAGE_DIR().key(),
tmpDir3.getAbsolutePath());
-
String id1 = UUID.randomUUID().toString();
String id2 = UUID.randomUUID().toString();
String id3 = UUID.randomUUID().toString();
@@ -133,6 +129,7 @@ public class RatisMasterStatusSystemSuiteJ {
.setRatisPort(ratisPort1)
.setRpcPort(ratisPort1)
.setInternalRpcPort(ratisPort1)
+ .setSslEnabled(sslEnabled)
.setNodeId(id1)
.build();
MasterNode masterNode2 =
@@ -141,6 +138,7 @@ public class RatisMasterStatusSystemSuiteJ {
.setRatisPort(ratisPort2)
.setRpcPort(ratisPort2)
.setInternalRpcPort(ratisPort2)
+ .setSslEnabled(sslEnabled)
.setNodeId(id2)
.build();
MasterNode masterNode3 =
@@ -149,6 +147,7 @@ public class RatisMasterStatusSystemSuiteJ {
.setRatisPort(ratisPort3)
.setRpcPort(ratisPort3)
.setInternalRpcPort(ratisPort3)
+ .setSslEnabled(sslEnabled)
.setNodeId(id3)
.build();
@@ -304,7 +303,7 @@ public class RatisMasterStatusSystemSuiteJ {
} catch (CelebornRuntimeException e) {
Assert.assertTrue(true);
} finally {
- resetRaftServer();
+ init();
}
}
diff --git
a/master/src/test/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/SSLRatisMasterStatusSystemSuiteJ.java
b/master/src/test/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/SSLRatisMasterStatusSystemSuiteJ.java
new file mode 100644
index 000000000..f3d7bca0a
--- /dev/null
+++
b/master/src/test/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/SSLRatisMasterStatusSystemSuiteJ.java
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.service.deploy.master.clustermeta.ha;
+
+import static org.junit.Assert.assertTrue;
+
+import java.io.File;
+import java.net.InetAddress;
+import java.net.InetSocketAddress;
+import java.net.Socket;
+import java.security.KeyPair;
+import java.security.cert.X509Certificate;
+import java.util.concurrent.atomic.AtomicReference;
+
+import javax.net.ssl.SSLContext;
+import javax.net.ssl.SSLSocket;
+import javax.net.ssl.SSLSocketFactory;
+import javax.net.ssl.TrustManager;
+
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.common.CelebornConf$;
+import org.apache.celeborn.common.network.ssl.SSLFactory;
+import org.apache.celeborn.common.network.ssl.SslSampleConfigs;
+import org.apache.celeborn.common.protocol.TransportModuleConstants;
+import org.apache.celeborn.common.util.Utils;
+
+public class SSLRatisMasterStatusSystemSuiteJ extends
RatisMasterStatusSystemSuiteJ {
+
+ private static final CelebornConf confWithHostPreferred = new CelebornConf();
+
+ static {
+ confWithHostPreferred.set(CelebornConf$.MODULE$.NETWORK_BIND_PREFER_IP(),
false);
+ }
+
+ private static class CertificateData {
+ final File file;
+ final KeyPair keyPair;
+ final X509Certificate cert;
+
+ // If caData is null, we are generating for CA - else for a cert which is
using the ca
+ // from caData
+ CertificateData(CertificateData caData) throws Exception {
+ this.file = File.createTempFile("file", ".jks");
+ file.deleteOnExit();
+
+ this.keyPair = SslSampleConfigs.generateKeyPair("RSA");
+
+ // for both ca and cert, we are simply using the same machien as CN
+ String hostname = Utils.localHostName(confWithHostPreferred);
+ final String dn = "CN=" + hostname + ",O=MyCompany,C=US";
+
+ if (null != caData) {
+ this.cert =
+ SslSampleConfigs.generateCertificate(
+ dn,
+ keyPair,
+ 365,
+ "SHA256withRSA",
+ false,
+ new String[] {hostname},
+ caData.keyPair,
+ caData.cert);
+ SslSampleConfigs.createKeyStore(
+ file, "password", "password", "cert", keyPair.getPrivate(), cert);
+ } else {
+ this.cert =
+ SslSampleConfigs.generateCertificate(
+ dn, keyPair, 365, "SHA256withRSA", true, null, null, null);
+ SslSampleConfigs.createTrustStore(file, "password", "ca", cert);
+ }
+ }
+ }
+
+ private static final AtomicReference<CertificateData> caData;
+
+ static {
+ try {
+ caData = new AtomicReference<>(new CertificateData(null));
+ } catch (Exception ex) {
+ throw new IllegalStateException("Unable to initialize", ex);
+ }
+ }
+
+ @BeforeClass
+ public static void init() throws Exception {
+
+ resetRaftServer(
+ configureSsl(caData.get(), configureServerConf(new CelebornConf(), 1)),
+ configureSsl(caData.get(), configureServerConf(new CelebornConf(), 2)),
+ configureSsl(caData.get(), configureServerConf(new CelebornConf(), 3)),
+ true);
+ }
+
+ static CelebornConf configureSsl(CertificateData ca, CelebornConf conf)
throws Exception {
+ conf.set("celeborn.master.ha.ratis.raft.rpc.type", "GRPC");
+
+ CertificateData server = new CertificateData(ca);
+
+ final String module = TransportModuleConstants.RPC_SERVICE_MODULE;
+
+ conf.set("celeborn.ssl." + module + ".enabled", "true");
+ conf.set("celeborn.ssl." + module + ".keyStore",
server.file.getAbsolutePath());
+
+ conf.set("celeborn.ssl." + module + ".keyStorePassword", "password");
+ conf.set("celeborn.ssl." + module + ".keyPassword", "password");
+ conf.set("celeborn.ssl." + module + ".privateKeyPassword", "password");
+ conf.set("celeborn.ssl." + module + ".protocol", "TLSv1.2");
+ conf.set("celeborn.ssl." + module + ".trustStore",
ca.file.getAbsolutePath());
+ conf.set("celeborn.ssl." + module + ".trustStorePassword", "password");
+
+ return conf;
+ }
+
+ @Test
+ public void testSslEnabled() throws Exception {
+ assertTrue(isSslServer(RATISSERVER1.getRaftAddress(),
RATISSERVER1.getRaftPort()));
+ assertTrue(isSslServer(RATISSERVER2.getRaftAddress(),
RATISSERVER2.getRaftPort()));
+ assertTrue(isSslServer(RATISSERVER3.getRaftAddress(),
RATISSERVER3.getRaftPort()));
+ }
+
+ // Validate if the server listening at the port is using TLS or not.
+ static boolean isSslServer(InetAddress address, int port) throws Exception {
+ try (SSLSocket socket = createSslSocket(address, port)) {
+ socket.setSoTimeout(5000);
+ socket.startHandshake();
+ // handshake succeeded, this will always return true in this case
+ return socket.getSession().isValid();
+ }
+ }
+
+ private static SSLSocket createSslSocket(InetAddress address, int port)
throws Exception {
+ TrustManager trustStore =
SSLFactory.defaultTrustManagers(caData.get().file, "password")[0];
+ SSLContext context = SSLContext.getInstance("TLS");
+ context.init(null, new TrustManager[] {trustStore}, null);
+ SSLSocketFactory factory = context.getSocketFactory();
+ Socket socket = new Socket();
+ socket.connect(new InetSocketAddress(address, port), 5000);
+ socket.setSoTimeout(5000);
+
+ return (SSLSocket) factory.createSocket(socket, address.getHostAddress(),
port, true);
+ }
+}
diff --git a/project/CelebornBuild.scala b/project/CelebornBuild.scala
index 4c889c635..d75c6ba46 100644
--- a/project/CelebornBuild.scala
+++ b/project/CelebornBuild.scala
@@ -528,6 +528,7 @@ object CelebornService {
object CelebornMaster {
lazy val master = Project("celeborn-master", file("master"))
.dependsOn(CelebornCommon.common)
+ .dependsOn(CelebornCommon.common % "test->test;compile->compile")
.dependsOn(CelebornService.service % "test->test;compile->compile")
.settings (
commonSettings,