This is an automated email from the ASF dual-hosted git repository.

szetszwo pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/ratis.git


The following commit(s) were added to refs/heads/master by this push:
     new 729d3dce2 RATIS-1747. Support keyManager and trustManager in tlsConfig. (#785)
729d3dce2 is described below

commit 729d3dce220f352a720cbe38d0762399b6e55f06
Author: Sammi Chen <[email protected]>
AuthorDate: Sat Nov 19 04:08:44 2022 +0800

    RATIS-1747. Support keyManager and trustManager in tlsConfig. (#785)
---
 pom.xml                                            |   2 +
 .../java/org/apache/ratis/security/TlsConf.java    |  47 ++++-
 .../apache/ratis/security/SecurityTestUtils.java   |  65 -------
 .../java/org/apache/ratis/grpc/GrpcTlsConfig.java  |  10 ++
 .../main/java/org/apache/ratis/grpc/GrpcUtil.java  |  58 ++++++
 .../grpc/client/GrpcClientProtocolClient.java      |  14 +-
 .../grpc/server/GrpcServerProtocolClient.java      |  14 +-
 .../org/apache/ratis/grpc/server/GrpcService.java  |  14 +-
 .../java/org/apache/ratis/netty/NettyUtils.java    |  16 ++
 ratis-test/pom.xml                                 |  13 ++
 .../apache/ratis/grpc/TestRaftServerWithGrpc.java  |  40 ++++-
 .../apache/ratis/security/SecurityTestUtils.java   | 198 +++++++++++++++++++++
 12 files changed, 388 insertions(+), 103 deletions(-)

diff --git a/pom.xml b/pom.xml
index c8b3f4607..21641b4db 100644
--- a/pom.xml
+++ b/pom.xml
@@ -218,6 +218,8 @@
     <test.exclude.pattern>_</test.exclude.pattern>
     <!-- number of threads/forks to use when running tests in parallel, see 
parallel-tests profile -->
     <testsThreadCount>4</testsThreadCount>
+
+    <bouncycastle.version>1.70</bouncycastle.version>
   </properties>
 
   <dependencyManagement>
diff --git a/ratis-common/src/main/java/org/apache/ratis/security/TlsConf.java b/ratis-common/src/main/java/org/apache/ratis/security/TlsConf.java
index c9f2e5ebf..30cf67c83 100644
--- a/ratis-common/src/main/java/org/apache/ratis/security/TlsConf.java
+++ b/ratis-common/src/main/java/org/apache/ratis/security/TlsConf.java
@@ -20,6 +20,8 @@ package org.apache.ratis.security;
 import org.apache.ratis.util.JavaUtils;
 import org.apache.ratis.util.Preconditions;
 
+import javax.net.ssl.KeyManager;
+import javax.net.ssl.TrustManager;
 import java.io.File;
 import java.security.PrivateKey;
 import java.security.cert.X509Certificate;
@@ -97,15 +99,26 @@ public class TlsConf {
   public static final class TrustManagerConf {
     /** Trust certificates. */
     private final CertificatesConf trustCertificates;
+    private final TrustManager trustManager;
 
     private TrustManagerConf(CertificatesConf trustCertificates) {
       this.trustCertificates = trustCertificates;
+      this.trustManager = null;
+    }
+
+    private TrustManagerConf(TrustManager trustManager) {
+      this.trustManager = trustManager;
+      this.trustCertificates = null;
     }
 
     /** @return the trust certificates. */
     public CertificatesConf getTrustCertificates() {
       return trustCertificates;
     }
+
+    public TrustManager getTrustManager() {
+      return trustManager;
+    }
   }
 
   /** Configurations for a key manager. */
@@ -114,6 +127,7 @@ public class TlsConf {
     private final PrivateKeyConf privateKey;
     /** Certificates for the private key. */
     private final CertificatesConf keyCertificates;
+    private final KeyManager keyManager;
 
     private KeyManagerConf(PrivateKeyConf privateKey, CertificatesConf 
keyCertificates) {
       this.privateKey = Objects.requireNonNull(privateKey, "privateKey == 
null");
@@ -122,6 +136,13 @@ public class TlsConf {
           () -> "The privateKey (isFileBased? " + privateKey.isFileBased()
               + ") and the keyCertificates (isFileBased? " + 
keyCertificates.isFileBased()
               + ") must be either both file based or both not.");
+      keyManager = null;
+    }
+
+    private KeyManagerConf(KeyManager keyManager) {
+      this.keyManager = keyManager;
+      this.privateKey = null;
+      this.keyCertificates = null;
     }
 
     /** @return the private key. */
@@ -137,6 +158,10 @@ public class TlsConf {
     public boolean isFileBased() {
       return privateKey.isFileBased();
     }
+
+    public KeyManager getKeyManager() {
+      return keyManager;
+    }
   }
 
   private static final AtomicInteger COUNT = new AtomicInteger();
@@ -188,6 +213,8 @@ public class TlsConf {
     private PrivateKeyConf privateKey;
     private CertificatesConf keyCertificates;
     private boolean mutualTls;
+    private KeyManager keyManager;
+    private TrustManager trustManager;
 
     public Builder setName(String name) {
       this.name = name;
@@ -209,6 +236,16 @@ public class TlsConf {
       return this;
     }
 
+    public Builder setKeyManager(KeyManager keyManager) {
+      this.keyManager = keyManager;
+      return this;
+    }
+
+    public Builder setTrustManager(TrustManager trustManager) {
+      this.trustManager = trustManager;
+      return this;
+    }
+
     public Builder setMutualTls(boolean mutualTls) {
       this.mutualTls = mutualTls;
       return this;
@@ -223,11 +260,17 @@ public class TlsConf {
     }
 
     private TrustManagerConf buildTrustManagerConf() {
-      return new TrustManagerConf(trustCertificates);
+      if (trustManager != null) {
+        return new TrustManagerConf(trustManager);
+      } else {
+        return new TrustManagerConf(trustCertificates);
+      }
     }
 
     private KeyManagerConf buildKeyManagerConf() {
-      if (privateKey == null && keyCertificates == null) {
+      if (keyManager != null) {
+        return new KeyManagerConf(keyManager);
+      } else if (privateKey == null && keyCertificates == null) {
         return null;
       } else if (privateKey != null && keyCertificates != null) {
         return new KeyManagerConf(privateKey, keyCertificates);
diff --git a/ratis-common/src/test/java/org/apache/ratis/security/SecurityTestUtils.java b/ratis-common/src/test/java/org/apache/ratis/security/SecurityTestUtils.java
deleted file mode 100644
index d35a38be0..000000000
--- a/ratis-common/src/test/java/org/apache/ratis/security/SecurityTestUtils.java
+++ /dev/null
@@ -1,65 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.ratis.security;
-
-import org.apache.ratis.security.TlsConf.Builder;
-import org.apache.ratis.security.TlsConf.CertificatesConf;
-import org.apache.ratis.security.TlsConf.PrivateKeyConf;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import java.io.File;
-import java.net.URL;
-import java.util.Optional;
-
-public interface SecurityTestUtils {
-  Logger LOG = LoggerFactory.getLogger(SecurityTestUtils.class);
-
-  ClassLoader CLASS_LOADER = SecurityTestUtils.class.getClassLoader();
-
-  static File getResource(String name) {
-    final File file = Optional.ofNullable(CLASS_LOADER.getResource(name))
-        .map(URL::getFile)
-        .map(File::new)
-        .orElse(null);
-    LOG.info("Getting resource {}: {}", name, file);
-    return file;
-  }
-
-  static TlsConf newServerTlsConfig(boolean mutualAuthn) {
-    LOG.info("newServerTlsConfig: mutualAuthn? {}", mutualAuthn);
-    return new Builder()
-        .setName("server")
-        .setPrivateKey(new PrivateKeyConf(getResource("ssl/server.pem")))
-        .setKeyCertificates(new 
CertificatesConf(getResource("ssl/server.crt")))
-        .setTrustCertificates(new 
CertificatesConf(getResource("ssl/client.crt")))
-        .setMutualTls(mutualAuthn)
-        .build();
-  }
-
-  static TlsConf newClientTlsConfig(boolean mutualAuthn) {
-    LOG.info("newClientTlsConfig: mutualAuthn? {}", mutualAuthn);
-    return new Builder()
-        .setName("client")
-        .setPrivateKey(new PrivateKeyConf(getResource("ssl/client.pem")))
-        .setKeyCertificates(new 
CertificatesConf(getResource("ssl/client.crt")))
-        .setTrustCertificates(new CertificatesConf(getResource("ssl/ca.crt")))
-        .setMutualTls(mutualAuthn)
-        .build();
-  }
-}
\ No newline at end of file
diff --git a/ratis-grpc/src/main/java/org/apache/ratis/grpc/GrpcTlsConfig.java b/ratis-grpc/src/main/java/org/apache/ratis/grpc/GrpcTlsConfig.java
index 13176e503..ff540c3cc 100644
--- a/ratis-grpc/src/main/java/org/apache/ratis/grpc/GrpcTlsConfig.java
+++ b/ratis-grpc/src/main/java/org/apache/ratis/grpc/GrpcTlsConfig.java
@@ -19,6 +19,8 @@ package org.apache.ratis.grpc;
 
 import org.apache.ratis.security.TlsConf;
 
+import javax.net.ssl.KeyManager;
+import javax.net.ssl.TrustManager;
 import java.io.File;
 import java.security.PrivateKey;
 import java.security.cert.X509Certificate;
@@ -105,6 +107,10 @@ public class GrpcTlsConfig extends TlsConf {
     this.fileBasedConfig = fileBasedConfig;
   }
 
+  public GrpcTlsConfig(KeyManager keyManager, TrustManager trustManager, 
boolean mTlsEnabled) {
+    this(newBuilder(keyManager, trustManager, mTlsEnabled), false);
+  }
+
   private static Builder newBuilder(PrivateKey privateKey, X509Certificate 
certChain,
       List<X509Certificate> trustStore, boolean mTlsEnabled) {
     final Builder b = newBuilder().setMutualTls(mTlsEnabled);
@@ -121,4 +127,8 @@ public class GrpcTlsConfig extends TlsConf {
     
Optional.ofNullable(certChainFile).map(CertificatesConf::new).ifPresent(b::setKeyCertificates);
     return b;
   }
+
+  private static Builder newBuilder(KeyManager keyManager, TrustManager 
trustManager, boolean mTlsEnabled) {
+    return 
newBuilder().setMutualTls(mTlsEnabled).setKeyManager(keyManager).setTrustManager(trustManager);
+  }
 }
\ No newline at end of file
diff --git a/ratis-grpc/src/main/java/org/apache/ratis/grpc/GrpcUtil.java b/ratis-grpc/src/main/java/org/apache/ratis/grpc/GrpcUtil.java
index 6f50e150c..57673c991 100644
--- a/ratis-grpc/src/main/java/org/apache/ratis/grpc/GrpcUtil.java
+++ b/ratis-grpc/src/main/java/org/apache/ratis/grpc/GrpcUtil.java
@@ -20,11 +20,16 @@ package org.apache.ratis.grpc;
 import org.apache.ratis.protocol.RaftClientReply;
 import org.apache.ratis.protocol.exceptions.ServerNotReadyException;
 import org.apache.ratis.protocol.exceptions.TimeoutIOException;
+import org.apache.ratis.security.TlsConf.TrustManagerConf;
+import org.apache.ratis.security.TlsConf.CertificatesConf;
+import org.apache.ratis.security.TlsConf.PrivateKeyConf;
+import org.apache.ratis.security.TlsConf.KeyManagerConf;
 import org.apache.ratis.thirdparty.io.grpc.ManagedChannel;
 import org.apache.ratis.thirdparty.io.grpc.Metadata;
 import org.apache.ratis.thirdparty.io.grpc.Status;
 import org.apache.ratis.thirdparty.io.grpc.StatusRuntimeException;
 import org.apache.ratis.thirdparty.io.grpc.stub.StreamObserver;
+import org.apache.ratis.thirdparty.io.netty.handler.ssl.SslContextBuilder;
 import org.apache.ratis.util.IOUtils;
 import org.apache.ratis.util.JavaUtils;
 import org.apache.ratis.util.LogUtils;
@@ -33,6 +38,8 @@ import org.apache.ratis.util.function.CheckedSupplier;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import javax.net.ssl.KeyManager;
+import javax.net.ssl.TrustManager;
 import java.io.IOException;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.TimeUnit;
@@ -232,4 +239,55 @@ public interface GrpcUtil {
       }
     }
   }
+
+  static SslContextBuilder initSslContextBuilderForServer(KeyManagerConf 
keyManagerConfig) {
+    final KeyManager keyManager = keyManagerConfig.getKeyManager();
+    if (keyManager != null) {
+      return SslContextBuilder.forServer(keyManager);
+    }
+    final PrivateKeyConf privateKey = keyManagerConfig.getPrivateKey();
+    final CertificatesConf certificates = 
keyManagerConfig.getKeyCertificates();
+
+    if (keyManagerConfig.isFileBased()) {
+      return SslContextBuilder.forServer(certificates.getFile(), 
privateKey.getFile());
+    } else {
+      return SslContextBuilder.forServer(privateKey.get(), certificates.get());
+    }
+  }
+
+  static void setTrustManager(SslContextBuilder b, TrustManagerConf 
trustManagerConfig) {
+    if (trustManagerConfig == null) {
+      return;
+    }
+    final TrustManager trustManager = trustManagerConfig.getTrustManager();
+    if (trustManager != null) {
+      b.trustManager(trustManager);
+      return;
+    }
+    final CertificatesConf certificates = 
trustManagerConfig.getTrustCertificates();
+    if (certificates.isFileBased()) {
+      b.trustManager(certificates.getFile());
+    } else {
+      b.trustManager(certificates.get());
+    }
+  }
+
+  static void setKeyManager(SslContextBuilder b, KeyManagerConf 
keyManagerConfig) {
+    if (keyManagerConfig == null) {
+      return;
+    }
+    final KeyManager keyManager = keyManagerConfig.getKeyManager();
+    if (keyManager != null) {
+      b.keyManager(keyManager);
+      return;
+    }
+    final PrivateKeyConf privateKey = keyManagerConfig.getPrivateKey();
+    final CertificatesConf certificates = 
keyManagerConfig.getKeyCertificates();
+
+    if (keyManagerConfig.isFileBased()) {
+      b.keyManager(certificates.getFile(), privateKey.getFile());
+    } else {
+      b.keyManager(privateKey.get(), certificates.get());
+    }
+  }
 }
diff --git a/ratis-grpc/src/main/java/org/apache/ratis/grpc/client/GrpcClientProtocolClient.java b/ratis-grpc/src/main/java/org/apache/ratis/grpc/client/GrpcClientProtocolClient.java
index d8b128a43..08bacdb73 100644
--- a/ratis-grpc/src/main/java/org/apache/ratis/grpc/client/GrpcClientProtocolClient.java
+++ b/ratis-grpc/src/main/java/org/apache/ratis/grpc/client/GrpcClientProtocolClient.java
@@ -131,19 +131,9 @@ public class GrpcClientProtocolClient implements Closeable {
 
     if (tlsConf != null) {
       SslContextBuilder sslContextBuilder = GrpcSslContexts.forClient();
-      if (tlsConf.isFileBasedConfig()) {
-        sslContextBuilder.trustManager(tlsConf.getTrustStoreFile());
-      } else {
-        sslContextBuilder.trustManager(tlsConf.getTrustStore());
-      }
+      GrpcUtil.setTrustManager(sslContextBuilder, tlsConf.getTrustManager());
       if (tlsConf.getMtlsEnabled()) {
-        if (tlsConf.isFileBasedConfig()) {
-          sslContextBuilder.keyManager(tlsConf.getCertChainFile(),
-              tlsConf.getPrivateKeyFile());
-        } else {
-          sslContextBuilder.keyManager(tlsConf.getPrivateKey(),
-              tlsConf.getCertChain());
-        }
+        GrpcUtil.setKeyManager(sslContextBuilder, tlsConf.getKeyManager());
       }
       try {
         channelBuilder.useTransportSecurity().sslContext(
diff --git 
a/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcServerProtocolClient.java
 
b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcServerProtocolClient.java
index 1494dd594..4c28c1df4 100644
--- a/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcServerProtocolClient.java
+++ b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcServerProtocolClient.java
@@ -79,19 +79,9 @@ public class GrpcServerProtocolClient implements Closeable {
 
     if (tlsConfig!= null) {
       SslContextBuilder sslContextBuilder = GrpcSslContexts.forClient();
-      if (tlsConfig.isFileBasedConfig()) {
-        sslContextBuilder.trustManager(tlsConfig.getTrustStoreFile());
-      } else {
-        sslContextBuilder.trustManager(tlsConfig.getTrustStore());
-      }
+      GrpcUtil.setTrustManager(sslContextBuilder, tlsConfig.getTrustManager());
       if (tlsConfig.getMtlsEnabled()) {
-        if (tlsConfig.isFileBasedConfig()) {
-          sslContextBuilder.keyManager(tlsConfig.getCertChainFile(),
-              tlsConfig.getPrivateKeyFile());
-        } else {
-          sslContextBuilder.keyManager(tlsConfig.getPrivateKey(),
-              tlsConfig.getCertChain());
-        }
+        GrpcUtil.setKeyManager(sslContextBuilder, tlsConfig.getKeyManager());
       }
       try {
         
channelBuilder.useTransportSecurity().sslContext(sslContextBuilder.build());
diff --git a/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcService.java b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcService.java
index f21619555..097900a0f 100644
--- a/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcService.java
+++ b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcService.java
@@ -20,6 +20,7 @@ package org.apache.ratis.grpc.server;
 import org.apache.ratis.conf.RaftProperties;
 import org.apache.ratis.grpc.GrpcConfigKeys;
 import org.apache.ratis.grpc.GrpcTlsConfig;
+import org.apache.ratis.grpc.GrpcUtil;
 import org.apache.ratis.grpc.metrics.intercept.server.MetricServerInterceptor;
 import org.apache.ratis.protocol.RaftGroupId;
 import org.apache.ratis.protocol.RaftPeerId;
@@ -270,19 +271,10 @@ public final class GrpcService extends RaftServerRpcWithProxy<GrpcServerProtocol
         .flowControlWindow(flowControlWindow.getSizeInt());
 
     if (tlsConfig != null) {
-      SslContextBuilder sslContextBuilder =
-          tlsConfig.isFileBasedConfig()?
-              SslContextBuilder.forServer(tlsConfig.getCertChainFile(),
-                  tlsConfig.getPrivateKeyFile()):
-              SslContextBuilder.forServer(tlsConfig.getPrivateKey(),
-                  tlsConfig.getCertChain());
+      SslContextBuilder sslContextBuilder = 
GrpcUtil.initSslContextBuilderForServer(tlsConfig.getKeyManager());
       if (tlsConfig.getMtlsEnabled()) {
         sslContextBuilder.clientAuth(ClientAuth.REQUIRE);
-        if (tlsConfig.isFileBasedConfig()) {
-          sslContextBuilder.trustManager(tlsConfig.getTrustStoreFile());
-        } else {
-            sslContextBuilder.trustManager(tlsConfig.getTrustStore());
-        }
+        GrpcUtil.setTrustManager(sslContextBuilder, 
tlsConfig.getTrustManager());
       }
       sslContextBuilder = GrpcSslContexts.configure(sslContextBuilder, 
OPENSSL);
       try {
diff --git a/ratis-netty/src/main/java/org/apache/ratis/netty/NettyUtils.java b/ratis-netty/src/main/java/org/apache/ratis/netty/NettyUtils.java
index 146e1991c..ac37c801a 100644
--- a/ratis-netty/src/main/java/org/apache/ratis/netty/NettyUtils.java
+++ b/ratis-netty/src/main/java/org/apache/ratis/netty/NettyUtils.java
@@ -32,6 +32,8 @@ import org.apache.ratis.util.ConcurrentUtils;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import javax.net.ssl.KeyManager;
+import javax.net.ssl.TrustManager;
 import java.util.function.Function;
 
 public interface NettyUtils {
@@ -54,6 +56,11 @@ public interface NettyUtils {
     if (trustManagerConfig == null) {
       return;
     }
+    final TrustManager trustManager = trustManagerConfig.getTrustManager();
+    if (trustManager != null) {
+      b.trustManager(trustManager);
+      return;
+    }
     final CertificatesConf certificates = 
trustManagerConfig.getTrustCertificates();
     if (certificates.isFileBased()) {
       b.trustManager(certificates.getFile());
@@ -66,6 +73,11 @@ public interface NettyUtils {
     if (keyManagerConfig == null) {
       return;
     }
+    final KeyManager keyManager = keyManagerConfig.getKeyManager();
+    if (keyManager != null) {
+      b.keyManager(keyManager);
+      return;
+    }
     final PrivateKeyConf privateKey = keyManagerConfig.getPrivateKey();
     final CertificatesConf certificates = 
keyManagerConfig.getKeyCertificates();
 
@@ -77,6 +89,10 @@ public interface NettyUtils {
   }
 
   static SslContextBuilder initSslContextBuilderForServer(KeyManagerConf 
keyManagerConfig) {
+    final KeyManager keyManager = keyManagerConfig.getKeyManager();
+    if (keyManager != null) {
+      return SslContextBuilder.forServer(keyManager);
+    }
     final PrivateKeyConf privateKey = keyManagerConfig.getPrivateKey();
     final CertificatesConf certificates = 
keyManagerConfig.getKeyCertificates();
 
diff --git a/ratis-test/pom.xml b/ratis-test/pom.xml
index 99f4e21cc..b1a08f775 100644
--- a/ratis-test/pom.xml
+++ b/ratis-test/pom.xml
@@ -96,6 +96,19 @@
       </exclusions>
     </dependency>
 
+    <dependency>
+      <groupId>org.bouncycastle</groupId>
+      <artifactId>bcprov-jdk15on</artifactId>
+      <version>${bouncycastle.version}</version>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.bouncycastle</groupId>
+      <artifactId>bcpkix-jdk15on</artifactId>
+      <version>${bouncycastle.version}</version>
+      <scope>test</scope>
+    </dependency>
+
     <dependency>
       <groupId>org.slf4j</groupId>
       <artifactId>slf4j-api</artifactId>
diff --git a/ratis-test/src/test/java/org/apache/ratis/grpc/TestRaftServerWithGrpc.java b/ratis-test/src/test/java/org/apache/ratis/grpc/TestRaftServerWithGrpc.java
index e85b6ad19..a0e721b01 100644
--- a/ratis-test/src/test/java/org/apache/ratis/grpc/TestRaftServerWithGrpc.java
+++ b/ratis-test/src/test/java/org/apache/ratis/grpc/TestRaftServerWithGrpc.java
@@ -23,6 +23,9 @@ import static org.apache.ratis.server.metrics.RaftServerMetricsImpl.RAFT_CLIENT_
 import static 
org.apache.ratis.server.metrics.RaftServerMetricsImpl.RAFT_CLIENT_WRITE_REQUEST;
 
 import org.apache.ratis.metrics.RatisMetricRegistry;
+import org.slf4j.event.Level;
+import org.apache.ratis.conf.Parameters;
+import org.apache.ratis.security.SecurityTestUtils;
 import org.apache.ratis.server.storage.RaftStorage;
 import org.apache.ratis.BaseTest;
 import org.apache.ratis.metrics.impl.DefaultTimekeeperImpl;
@@ -61,8 +64,9 @@ import org.junit.Before;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
-import org.slf4j.event.Level;
 
+import javax.net.ssl.KeyManager;
+import javax.net.ssl.TrustManager;
 import java.io.IOException;
 import java.nio.channels.OverlappingFileLockException;
 import java.util.*;
@@ -341,4 +345,38 @@ public class TestRaftServerWithGrpc extends BaseTest implements MiniRaftClusterW
     return RaftClientTestUtil.newRaftClientRequest(client, serverId, seqNum, m,
         RaftClientRequest.writeRequestType(), 
ProtoUtils.toSlidingWindowEntry(seqNum, seqNum == 1L));
   }
+
+  @Test
+  public void testTlsWithKeyAndTrustManager() throws Exception {
+    final RaftProperties p = getProperties();
+    RaftServerConfigKeys.Write.setElementLimit(p, 10);
+    RaftServerConfigKeys.Write.setByteLimit(p, SizeInBytes.valueOf("1MB"));
+    String[] ids = MiniRaftCluster.generateIds(3, 3);
+
+    KeyManager serverKeyManager = 
SecurityTestUtils.getKeyManager(SecurityTestUtils::getServerKeyStore);
+    TrustManager serverTrustManager = 
SecurityTestUtils.getTrustManager(SecurityTestUtils::getTrustStore);
+    KeyManager clientKeyManager = 
SecurityTestUtils.getKeyManager(SecurityTestUtils::getClientKeyStore);
+    TrustManager clientTrustManager = 
SecurityTestUtils.getTrustManager(SecurityTestUtils::getTrustStore);
+
+    GrpcTlsConfig serverConfig = new GrpcTlsConfig(serverKeyManager, 
serverTrustManager, true);
+    GrpcTlsConfig clientConfig = new GrpcTlsConfig(clientKeyManager, 
clientTrustManager, true);
+
+    final Parameters parameters = new Parameters();
+    GrpcConfigKeys.Server.setTlsConf(parameters, serverConfig);
+    GrpcConfigKeys.Admin.setTlsConf(parameters, clientConfig);
+    GrpcConfigKeys.Client.setTlsConf(parameters, clientConfig);
+
+    MiniRaftClusterWithGrpc cluster = null;
+    try {
+      cluster = new MiniRaftClusterWithGrpc(ids, new String[0], p, parameters);
+      cluster.start();
+      testRequestMetrics(cluster);
+    }  finally {
+      RaftServerConfigKeys.Write.setElementLimit(p, 
RaftServerConfigKeys.Write.ELEMENT_LIMIT_DEFAULT);
+      RaftServerConfigKeys.Write.setByteLimit(p, 
RaftServerConfigKeys.Write.BYTE_LIMIT_DEFAULT);
+      if (cluster != null) {
+        cluster.shutdown();
+      }
+    }
+  }
 }
diff --git a/ratis-test/src/test/java/org/apache/ratis/security/SecurityTestUtils.java b/ratis-test/src/test/java/org/apache/ratis/security/SecurityTestUtils.java
new file mode 100644
index 000000000..a981282fd
--- /dev/null
+++ b/ratis-test/src/test/java/org/apache/ratis/security/SecurityTestUtils.java
@@ -0,0 +1,198 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.ratis.security;
+
+import org.apache.ratis.security.TlsConf.Builder;
+import org.apache.ratis.security.TlsConf.CertificatesConf;
+import org.apache.ratis.security.TlsConf.PrivateKeyConf;
+import org.bouncycastle.util.io.pem.PemObject;
+import org.bouncycastle.util.io.pem.PemReader;
+import org.junit.Assert;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import javax.net.ssl.KeyManager;
+import javax.net.ssl.KeyManagerFactory;
+import javax.net.ssl.TrustManager;
+import javax.net.ssl.TrustManagerFactory;
+import javax.net.ssl.X509TrustManager;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileReader;
+import java.net.URL;
+import java.security.KeyFactory;
+import java.security.KeyStore;
+import java.security.KeyStoreException;
+import java.security.NoSuchAlgorithmException;
+import java.security.PrivateKey;
+import java.security.UnrecoverableKeyException;
+import java.security.cert.CertificateFactory;
+import java.security.cert.X509Certificate;
+import java.security.spec.PKCS8EncodedKeySpec;
+import java.util.Arrays;
+import java.util.Optional;
+import java.util.function.Supplier;
+
+public interface SecurityTestUtils {
+  Logger LOG = LoggerFactory.getLogger(SecurityTestUtils.class);
+
+  ClassLoader CLASS_LOADER = SecurityTestUtils.class.getClassLoader();
+
+  static File getResource(String name) {
+    final File file = Optional.ofNullable(CLASS_LOADER.getResource(name))
+        .map(URL::getFile)
+        .map(File::new)
+        .orElse(null);
+    LOG.info("Getting resource {}: {}", name, file);
+    return file;
+  }
+
+  static TlsConf newServerTlsConfig(boolean mutualAuthn) {
+    LOG.info("newServerTlsConfig: mutualAuthn? {}", mutualAuthn);
+    return new Builder()
+        .setName("server")
+        .setPrivateKey(new PrivateKeyConf(getResource("ssl/server.pem")))
+        .setKeyCertificates(new 
CertificatesConf(getResource("ssl/server.crt")))
+        .setTrustCertificates(new 
CertificatesConf(getResource("ssl/client.crt")))
+        .setMutualTls(mutualAuthn)
+        .build();
+  }
+
+  static TlsConf newClientTlsConfig(boolean mutualAuthn) {
+    LOG.info("newClientTlsConfig: mutualAuthn? {}", mutualAuthn);
+    return new Builder()
+        .setName("client")
+        .setPrivateKey(new PrivateKeyConf(getResource("ssl/client.pem")))
+        .setKeyCertificates(new 
CertificatesConf(getResource("ssl/client.crt")))
+        .setTrustCertificates(new CertificatesConf(getResource("ssl/ca.crt")))
+        .setMutualTls(mutualAuthn)
+        .build();
+  }
+
+  static PrivateKey getPrivateKey(String keyPath) {
+    try {
+      File file = getResource(keyPath);
+      FileReader keyReader = new FileReader(file);
+      PemReader pemReader = new PemReader(keyReader);
+      PemObject pemObject = pemReader.readPemObject();
+      pemReader.close();
+      keyReader.close();
+
+      byte[] content = pemObject.getContent();
+      PKCS8EncodedKeySpec privKeySpec = new PKCS8EncodedKeySpec(content);
+
+      KeyFactory keyFactory = KeyFactory.getInstance("RSA");
+      return keyFactory.generatePrivate(privKeySpec);
+    } catch (Exception e) {
+      Assert.fail("Failed to get private key from " + keyPath + ". Error: "  +
+          e.getMessage());
+    }
+    return null;
+  }
+
+  static X509Certificate[] getCertificate(String certPath) {
+    try {
+      // Read certificates
+      X509Certificate[] certificate = new X509Certificate[1];
+      CertificateFactory fact = CertificateFactory.getInstance("X.509");
+      try (FileInputStream is = new FileInputStream(getResource(certPath))) {
+        certificate[0] = (X509Certificate) fact.generateCertificate(is);
+      }
+      return certificate;
+    } catch (Exception e) {
+      Assert.fail("Failed to get certificate from " + certPath + ". Error: "  +
+          e.getMessage());
+    }
+    return null;
+  }
+
+  static KeyStore getServerKeyStore() {
+    try {
+      PrivateKey privateKey = getPrivateKey("ssl/server.pem");
+      X509Certificate[] certificate = getCertificate("ssl/server.crt");
+
+      // build keyStore
+      KeyStore keyStore = KeyStore.getInstance(KeyStore.getDefaultType());
+      keyStore.load(null, null);
+      keyStore.setKeyEntry("ratis-server-key", privateKey, new char[0], 
certificate);
+      return keyStore;
+    } catch (Exception e) {
+      Assert.fail("Failed to get sever key store " + e.getMessage());
+    }
+    return null;
+  }
+
+  static KeyStore getClientKeyStore() {
+    try {
+      PrivateKey privateKey = getPrivateKey("ssl/client.pem");
+      X509Certificate[] certificate = getCertificate("ssl/client.crt");
+
+      // build keyStore
+      KeyStore keyStore = KeyStore.getInstance(KeyStore.getDefaultType());
+      keyStore.load(null, null);
+      keyStore.setKeyEntry("ratis-client-key", privateKey, new char[0], 
certificate);
+      return keyStore;
+    } catch (Exception e) {
+      Assert.fail("Failed to get client key store " + e.getMessage());
+    }
+    return null;
+  }
+
+  static KeyStore getTrustStore() {
+    try {
+      X509Certificate[] certificate = getCertificate("ssl/ca.crt");
+
+      // build trustStore
+      KeyStore trustStore = KeyStore.getInstance(KeyStore.getDefaultType());
+      trustStore.load(null, null);
+
+      for (X509Certificate cert: certificate) {
+        trustStore.setCertificateEntry(cert.getSerialNumber().toString(), 
cert);
+      }
+      return trustStore;
+    } catch (Exception e) {
+      Assert.fail("Failed to get sever key store " + e.getMessage());
+    }
+    return null;
+  }
+
+  static KeyManager getKeyManager(Supplier<KeyStore> supplier) throws 
KeyStoreException,
+      NoSuchAlgorithmException, UnrecoverableKeyException {
+    KeyStore keyStore = supplier.get();
+
+    KeyManagerFactory keyManagerFactory = 
KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
+    keyManagerFactory.init(keyStore, new char[0]);
+
+    KeyManager[] managers = keyManagerFactory.getKeyManagers();
+    return managers[0];
+  }
+
+  static X509TrustManager getTrustManager(Supplier<KeyStore> supplier) throws 
KeyStoreException,
+      NoSuchAlgorithmException {
+    KeyStore keyStore = supplier.get();
+    TrustManagerFactory trustManagerFactory = TrustManagerFactory.getInstance(
+        TrustManagerFactory.getDefaultAlgorithm());
+    trustManagerFactory.init(keyStore);
+    TrustManager[] trustManagers = trustManagerFactory.getTrustManagers();
+    if (trustManagers.length != 1 || !(trustManagers[0] instanceof 
X509TrustManager)) {
+      throw new IllegalStateException("Unexpected default trust managers:"
+          + Arrays.toString(trustManagers));
+    }
+    return (X509TrustManager) trustManagers[0];
+  }
+}
\ No newline at end of file

Reply via email to