This is an automated email from the ASF dual-hosted git repository.

szetszwo pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/ratis.git


The following commit(s) were added to refs/heads/master by this push:
     new d411b0288 RATIS-2331. Reuse SslContext in gRPC. (#1288)
d411b0288 is described below

commit d411b028882f2e397cee27cc5bf51e2f0dc4dd35
Author: Tsz-Wo Nicholas Sze <szets...@apache.org>
AuthorDate: Sat Sep 27 11:00:09 2025 -0700

    RATIS-2331. Reuse SslContext in gRPC. (#1288)
---
 .../main/java/org/apache/ratis/util/LifeCycle.java |   3 +
 .../java/org/apache/ratis/grpc/GrpcFactory.java    |  78 +++++++--------
 .../main/java/org/apache/ratis/grpc/GrpcUtil.java  |  40 ++++++++
 .../grpc/client/GrpcClientProtocolClient.java      |  28 ++----
 .../ratis/grpc/client/GrpcClientProtocolProxy.java | 108 ---------------------
 .../apache/ratis/grpc/client/GrpcClientRpc.java    |   6 +-
 .../grpc/server/GrpcServerProtocolClient.java      |  36 +++----
 .../apache/ratis/grpc/server/GrpcServicesImpl.java |  56 +++++------
 .../org/apache/ratis/grpc/server/GrpcStubPool.java |  29 ++----
 9 files changed, 134 insertions(+), 250 deletions(-)

diff --git a/ratis-common/src/main/java/org/apache/ratis/util/LifeCycle.java 
b/ratis-common/src/main/java/org/apache/ratis/util/LifeCycle.java
index 9870fe371..e96ba88a5 100644
--- a/ratis-common/src/main/java/org/apache/ratis/util/LifeCycle.java
+++ b/ratis-common/src/main/java/org/apache/ratis/util/LifeCycle.java
@@ -117,6 +117,9 @@ public class LifeCycle {
       if (LOG.isTraceEnabled()) {
         LOG.trace("TRACE", new Throwable());
       }
+      if (to == EXCEPTION) {
+        LOG.error("{} has failed ({} -> {})", name, from, to, new 
Throwable("TRACE"));
+      }
 
       Preconditions.assertTrue(isValid(from, to),
           "ILLEGAL TRANSITION: In %s, %s -> %s", name, from, to);
diff --git a/ratis-grpc/src/main/java/org/apache/ratis/grpc/GrpcFactory.java 
b/ratis-grpc/src/main/java/org/apache/ratis/grpc/GrpcFactory.java
index 331d1a858..1053cab80 100644
--- a/ratis-grpc/src/main/java/org/apache/ratis/grpc/GrpcFactory.java
+++ b/ratis-grpc/src/main/java/org/apache/ratis/grpc/GrpcFactory.java
@@ -32,11 +32,15 @@ import org.apache.ratis.server.ServerFactory;
 import org.apache.ratis.server.leader.FollowerInfo;
 import org.apache.ratis.server.leader.LeaderState;
 import org.apache.ratis.thirdparty.io.netty.buffer.PooledByteBufAllocator;
+import org.apache.ratis.thirdparty.io.netty.handler.ssl.SslContext;
 import org.apache.ratis.util.JavaUtils;
+import org.apache.ratis.util.MemoizedSupplier;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import java.util.function.BiFunction;
 import java.util.function.Consumer;
+import java.util.function.Supplier;
 
 public class GrpcFactory implements ServerFactory, ClientFactory {
 
@@ -65,19 +69,32 @@ public class GrpcFactory implements ServerFactory, 
ClientFactory {
     return value;
   }
 
-  private final GrpcServices.Customizer servicesCustomizer;
+  static final BiFunction<GrpcTlsConfig, SslContext, SslContext> 
BUILD_SSL_CONTEXT_FOR_SERVER
+      = (tlsConf, defaultContext) -> tlsConf == null ? defaultContext : 
GrpcUtil.buildSslContextForServer(tlsConf);
+
+  static final BiFunction<GrpcTlsConfig, SslContext, SslContext> 
BUILD_SSL_CONTEXT_FOR_CLIENT
+      = (tlsConf, defaultContext) -> tlsConf == null ? defaultContext : 
GrpcUtil.buildSslContextForClient(tlsConf);
 
-  private final GrpcTlsConfig tlsConfig;
-  private final GrpcTlsConfig adminTlsConfig;
-  private final GrpcTlsConfig clientTlsConfig;
-  private final GrpcTlsConfig serverTlsConfig;
+  static final class SslContexts {
+    private final SslContext adminSslContext;
+    private final SslContext clientSslContext;
+    private final SslContext serverSslContext;
 
-  public static Parameters newRaftParameters(GrpcTlsConfig conf) {
-    final Parameters p = new Parameters();
-    GrpcConfigKeys.TLS.setConf(p, conf);
-    return p;
+    private SslContexts(GrpcTlsConfig tlsConfig, GrpcTlsConfig adminTlsConfig,
+        GrpcTlsConfig clientTlsConfig, GrpcTlsConfig serverTlsConfig,
+        BiFunction<GrpcTlsConfig, SslContext, SslContext> buildMethod) {
+      final SslContext defaultSslContext = buildMethod.apply(tlsConfig, null);
+      this.adminSslContext = buildMethod.apply(adminTlsConfig, 
defaultSslContext);
+      this.clientSslContext = buildMethod.apply(clientTlsConfig, 
defaultSslContext);
+      this.serverSslContext = buildMethod.apply(serverTlsConfig, 
defaultSslContext);
+    }
   }
 
+  private final GrpcServices.Customizer servicesCustomizer;
+
+  private final Supplier<SslContexts> forServerSupplier;
+  private final Supplier<SslContexts> forClientSupplier;
+
   public GrpcFactory(Parameters parameters) {
     this(GrpcConfigKeys.Server.servicesCustomizer(parameters),
         GrpcConfigKeys.TLS.conf(parameters),
@@ -87,35 +104,15 @@ public class GrpcFactory implements ServerFactory, 
ClientFactory {
     );
   }
 
-  public GrpcFactory(GrpcTlsConfig tlsConfig) {
-    this(null, tlsConfig, null, null, null);
-  }
-
   private GrpcFactory(GrpcServices.Customizer servicesCustomizer,
       GrpcTlsConfig tlsConfig, GrpcTlsConfig adminTlsConfig,
       GrpcTlsConfig clientTlsConfig, GrpcTlsConfig serverTlsConfig) {
     this.servicesCustomizer = servicesCustomizer;
 
-    this.tlsConfig = tlsConfig;
-    this.adminTlsConfig = adminTlsConfig;
-    this.clientTlsConfig = clientTlsConfig;
-    this.serverTlsConfig = serverTlsConfig;
-  }
-
-  public GrpcTlsConfig getTlsConfig() {
-    return tlsConfig;
-  }
-
-  public GrpcTlsConfig getAdminTlsConfig() {
-    return adminTlsConfig != null ? adminTlsConfig : tlsConfig;
-  }
-
-  public GrpcTlsConfig getClientTlsConfig() {
-    return clientTlsConfig != null ? clientTlsConfig : tlsConfig;
-  }
-
-  public GrpcTlsConfig getServerTlsConfig() {
-    return serverTlsConfig != null ? serverTlsConfig : tlsConfig;
+    this.forServerSupplier = MemoizedSupplier.valueOf(() -> new SslContexts(
+        tlsConfig, adminTlsConfig, clientTlsConfig, serverTlsConfig, 
BUILD_SSL_CONTEXT_FOR_SERVER));
+    this.forClientSupplier = MemoizedSupplier.valueOf(() -> new SslContexts(
+        tlsConfig, adminTlsConfig, clientTlsConfig, serverTlsConfig, 
BUILD_SSL_CONTEXT_FOR_CLIENT));
   }
 
   @Override
@@ -131,19 +128,24 @@ public class GrpcFactory implements ServerFactory, 
ClientFactory {
   @Override
   public GrpcServices newRaftServerRpc(RaftServer server) {
     checkPooledByteBufAllocatorUseCacheForAllThreads(LOG::info);
+
+    final SslContexts forServer = forServerSupplier.get();
+    final SslContexts forClient = forClientSupplier.get();
     return GrpcServicesImpl.newBuilder()
         .setServer(server)
         .setCustomizer(servicesCustomizer)
-        .setAdminTlsConfig(getAdminTlsConfig())
-        .setServerTlsConfig(getServerTlsConfig())
-        .setClientTlsConfig(getClientTlsConfig())
+        .setAdminSslContext(forServer.adminSslContext)
+        .setServerSslContextForServer(forServer.serverSslContext)
+        .setServerSslContextForClient(forClient.serverSslContext)
+        .setClientSslContext(forServer.clientSslContext)
         .build();
   }
 
   @Override
   public GrpcClientRpc newRaftClientRpc(ClientId clientId, RaftProperties 
properties) {
     checkPooledByteBufAllocatorUseCacheForAllThreads(LOG::debug);
-    return new GrpcClientRpc(clientId, properties,
-        getAdminTlsConfig(), getClientTlsConfig());
+
+    final SslContexts forClient = forClientSupplier.get();
+    return new GrpcClientRpc(clientId, properties, forClient.adminSslContext, 
forClient.clientSslContext);
   }
 }
diff --git a/ratis-grpc/src/main/java/org/apache/ratis/grpc/GrpcUtil.java 
b/ratis-grpc/src/main/java/org/apache/ratis/grpc/GrpcUtil.java
index 2f9ee01ec..8dcfb6544 100644
--- a/ratis-grpc/src/main/java/org/apache/ratis/grpc/GrpcUtil.java
+++ b/ratis-grpc/src/main/java/org/apache/ratis/grpc/GrpcUtil.java
@@ -28,7 +28,10 @@ import org.apache.ratis.thirdparty.io.grpc.ManagedChannel;
 import org.apache.ratis.thirdparty.io.grpc.Metadata;
 import org.apache.ratis.thirdparty.io.grpc.Status;
 import org.apache.ratis.thirdparty.io.grpc.StatusRuntimeException;
+import org.apache.ratis.thirdparty.io.grpc.netty.GrpcSslContexts;
 import org.apache.ratis.thirdparty.io.grpc.stub.StreamObserver;
+import org.apache.ratis.thirdparty.io.netty.handler.ssl.ClientAuth;
+import org.apache.ratis.thirdparty.io.netty.handler.ssl.SslContext;
 import org.apache.ratis.thirdparty.io.netty.handler.ssl.SslContextBuilder;
 import org.apache.ratis.util.IOUtils;
 import org.apache.ratis.util.JavaUtils;
@@ -39,6 +42,7 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import javax.net.ssl.KeyManager;
+import javax.net.ssl.SSLException;
 import javax.net.ssl.TrustManager;
 import java.io.IOException;
 import java.util.concurrent.CompletableFuture;
@@ -46,6 +50,8 @@ import java.util.concurrent.TimeUnit;
 import java.util.function.Function;
 import java.util.function.Supplier;
 
+import static 
org.apache.ratis.thirdparty.io.netty.handler.ssl.SslProvider.OPENSSL;
+
 public interface GrpcUtil {
   Logger LOG = LoggerFactory.getLogger(GrpcUtil.class);
 
@@ -299,4 +305,38 @@ public interface GrpcUtil {
       b.keyManager(privateKey.get(), certificates.get());
     }
   }
+
+  static SslContext buildSslContextForServer(GrpcTlsConfig tlsConf) {
+    if (tlsConf == null) {
+      return null;
+    }
+    SslContextBuilder b = 
initSslContextBuilderForServer(tlsConf.getKeyManager());
+    if (tlsConf.getMtlsEnabled()) {
+      b.clientAuth(ClientAuth.REQUIRE);
+      setTrustManager(b, tlsConf.getTrustManager());
+    }
+    b = GrpcSslContexts.configure(b, OPENSSL);
+    try {
+      return b.build();
+    } catch (Exception e) {
+      throw new IllegalArgumentException("Failed to buildSslContextForServer 
from tlsConfig " + tlsConf, e);
+    }
+  }
+
+  static SslContext buildSslContextForClient(GrpcTlsConfig tlsConf) {
+    if (tlsConf == null) {
+      return null;
+    }
+
+    final SslContextBuilder b = GrpcSslContexts.forClient();
+    setTrustManager(b, tlsConf.getTrustManager());
+    if (tlsConf.getMtlsEnabled()) {
+      setKeyManager(b, tlsConf.getKeyManager());
+    }
+    try {
+      return b.build();
+    } catch (SSLException e) {
+      throw new IllegalArgumentException("Failed to buildSslContextForClient 
from tlsConfig " + tlsConf, e);
+    }
+  }
 }
diff --git 
a/ratis-grpc/src/main/java/org/apache/ratis/grpc/client/GrpcClientProtocolClient.java
 
b/ratis-grpc/src/main/java/org/apache/ratis/grpc/client/GrpcClientProtocolClient.java
index 3b9d51268..159919fab 100644
--- 
a/ratis-grpc/src/main/java/org/apache/ratis/grpc/client/GrpcClientProtocolClient.java
+++ 
b/ratis-grpc/src/main/java/org/apache/ratis/grpc/client/GrpcClientProtocolClient.java
@@ -21,7 +21,6 @@ import org.apache.ratis.client.RaftClientConfigKeys;
 import org.apache.ratis.client.impl.ClientProtoUtils;
 import org.apache.ratis.conf.RaftProperties;
 import org.apache.ratis.grpc.GrpcConfigKeys;
-import org.apache.ratis.grpc.GrpcTlsConfig;
 import org.apache.ratis.grpc.GrpcUtil;
 import org.apache.ratis.grpc.metrics.intercept.client.MetricClientInterceptor;
 import org.apache.ratis.proto.RaftProtos.GroupInfoReplyProto;
@@ -49,11 +48,10 @@ import 
org.apache.ratis.protocol.exceptions.NotLeaderException;
 import org.apache.ratis.protocol.exceptions.TimeoutIOException;
 import org.apache.ratis.thirdparty.io.grpc.ManagedChannel;
 import org.apache.ratis.thirdparty.io.grpc.StatusRuntimeException;
-import org.apache.ratis.thirdparty.io.grpc.netty.GrpcSslContexts;
 import org.apache.ratis.thirdparty.io.grpc.netty.NegotiationType;
 import org.apache.ratis.thirdparty.io.grpc.netty.NettyChannelBuilder;
 import org.apache.ratis.thirdparty.io.grpc.stub.StreamObserver;
-import org.apache.ratis.thirdparty.io.netty.handler.ssl.SslContextBuilder;
+import org.apache.ratis.thirdparty.io.netty.handler.ssl.SslContext;
 import org.apache.ratis.util.CollectionUtils;
 import org.apache.ratis.util.JavaUtils;
 import org.apache.ratis.util.SizeInBytes;
@@ -97,7 +95,7 @@ public class GrpcClientProtocolClient implements Closeable {
   private final MetricClientInterceptor metricClientInterceptor;
 
   GrpcClientProtocolClient(ClientId id, RaftPeer target, RaftProperties 
properties,
-      GrpcTlsConfig adminTlsConfig, GrpcTlsConfig clientTlsConfig) {
+      SslContext adminSslContext, SslContext clientSslContext) {
     this.name = JavaUtils.memoize(() -> id + "->" + target.getId());
     this.target = target;
     final SizeInBytes flowControlWindow = 
GrpcConfigKeys.flowControlWindow(properties, LOG::debug);
@@ -110,11 +108,9 @@ public class GrpcClientProtocolClient implements Closeable 
{
         .filter(x -> !x.isEmpty()).orElse(target.getAddress());
     final boolean separateAdminChannel = !Objects.equals(clientAddress, 
adminAddress);
 
-    clientChannel = buildChannel(clientAddress, clientTlsConfig,
-        flowControlWindow, maxMessageSize);
+    clientChannel = buildChannel(clientAddress, clientSslContext, 
flowControlWindow, maxMessageSize);
     adminChannel = separateAdminChannel
-        ? buildChannel(adminAddress, adminTlsConfig,
-            flowControlWindow, maxMessageSize)
+        ? buildChannel(adminAddress, adminSslContext, flowControlWindow, 
maxMessageSize)
         : clientChannel;
 
     asyncStub = RaftClientProtocolServiceGrpc.newStub(clientChannel);
@@ -124,26 +120,16 @@ public class GrpcClientProtocolClient implements 
Closeable {
         RaftClientConfigKeys.Rpc.watchRequestTimeout(properties);
   }
 
-  private ManagedChannel buildChannel(String address, GrpcTlsConfig tlsConf,
+  private ManagedChannel buildChannel(String address, SslContext sslContext,
       SizeInBytes flowControlWindow, SizeInBytes maxMessageSize) {
     NettyChannelBuilder channelBuilder =
         NettyChannelBuilder.forTarget(address);
     // ignore any http proxy for grpc
     channelBuilder.proxyDetector(uri -> null);
 
-    if (tlsConf != null) {
+    if (sslContext != null) {
       LOG.debug("Setting TLS for {}", address);
-      SslContextBuilder sslContextBuilder = GrpcSslContexts.forClient();
-      GrpcUtil.setTrustManager(sslContextBuilder, tlsConf.getTrustManager());
-      if (tlsConf.getMtlsEnabled()) {
-        GrpcUtil.setKeyManager(sslContextBuilder, tlsConf.getKeyManager());
-      }
-      try {
-        channelBuilder.useTransportSecurity().sslContext(
-            sslContextBuilder.build());
-      } catch (Exception ex) {
-        throw new RuntimeException(ex);
-      }
+      channelBuilder.useTransportSecurity().sslContext(sslContext);
     } else {
       channelBuilder.negotiationType(NegotiationType.PLAINTEXT);
     }
diff --git 
a/ratis-grpc/src/main/java/org/apache/ratis/grpc/client/GrpcClientProtocolProxy.java
 
b/ratis-grpc/src/main/java/org/apache/ratis/grpc/client/GrpcClientProtocolProxy.java
deleted file mode 100644
index 95119ef7d..000000000
--- 
a/ratis-grpc/src/main/java/org/apache/ratis/grpc/client/GrpcClientProtocolProxy.java
+++ /dev/null
@@ -1,108 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.ratis.grpc.client;
-
-import org.apache.ratis.conf.RaftProperties;
-import org.apache.ratis.grpc.GrpcTlsConfig;
-import org.apache.ratis.protocol.ClientId;
-import org.apache.ratis.thirdparty.io.grpc.stub.StreamObserver;
-import org.apache.ratis.proto.RaftProtos.RaftClientReplyProto;
-import org.apache.ratis.proto.RaftProtos.RaftClientRequestProto;
-import org.apache.ratis.protocol.RaftPeer;
-
-import java.io.Closeable;
-import java.io.IOException;
-import java.util.function.Function;
-
-public class GrpcClientProtocolProxy implements Closeable {
-  private final GrpcClientProtocolClient proxy;
-  private final Function<RaftPeer, CloseableStreamObserver> 
responseHandlerCreation;
-  private RpcSession currentSession;
-
-  public GrpcClientProtocolProxy(ClientId clientId, RaftPeer target,
-      Function<RaftPeer, CloseableStreamObserver> responseHandlerCreation,
-      RaftProperties properties, GrpcTlsConfig tlsConfig) {
-    proxy = new GrpcClientProtocolClient(clientId, target, properties, 
tlsConfig, tlsConfig);
-    this.responseHandlerCreation = responseHandlerCreation;
-  }
-
-  @Override
-  public void close() throws IOException {
-    closeCurrentSession();
-    proxy.close();
-  }
-
-  @Override
-  public String toString() {
-    return "ProxyTo:" + proxy.getTarget();
-  }
-
-  public void closeCurrentSession() {
-    if (currentSession != null) {
-      currentSession.close();
-      currentSession = null;
-    }
-  }
-
-  public void onNext(RaftClientRequestProto request) {
-    if (currentSession == null) {
-      currentSession = new RpcSession(
-          responseHandlerCreation.apply(proxy.getTarget()));
-    }
-    currentSession.requestObserver.onNext(request);
-  }
-
-  public void onError() {
-    if (currentSession != null) {
-      currentSession.onError();
-    }
-  }
-
-  public interface CloseableStreamObserver
-      extends StreamObserver<RaftClientReplyProto>, Closeable {
-  }
-
-  class RpcSession implements Closeable {
-    private final StreamObserver<RaftClientRequestProto> requestObserver;
-    private final CloseableStreamObserver responseHandler;
-    private boolean hasError = false;
-
-    RpcSession(CloseableStreamObserver responseHandler) {
-      this.responseHandler = responseHandler;
-      this.requestObserver = proxy.ordered(responseHandler);
-    }
-
-    void onError() {
-      hasError = true;
-    }
-
-    @Override
-    public void close() {
-      if (!hasError) {
-        try {
-          requestObserver.onCompleted();
-        } catch (Exception ignored) {
-        }
-      }
-      try {
-        responseHandler.close();
-      } catch (IOException ignored) {
-      }
-    }
-  }
-}
diff --git 
a/ratis-grpc/src/main/java/org/apache/ratis/grpc/client/GrpcClientRpc.java 
b/ratis-grpc/src/main/java/org/apache/ratis/grpc/client/GrpcClientRpc.java
index b825429ae..4010ade27 100644
--- a/ratis-grpc/src/main/java/org/apache/ratis/grpc/client/GrpcClientRpc.java
+++ b/ratis-grpc/src/main/java/org/apache/ratis/grpc/client/GrpcClientRpc.java
@@ -21,7 +21,6 @@ import org.apache.ratis.client.impl.ClientProtoUtils;
 import org.apache.ratis.client.impl.RaftClientRpcWithProxy;
 import org.apache.ratis.conf.RaftProperties;
 import org.apache.ratis.grpc.GrpcConfigKeys;
-import org.apache.ratis.grpc.GrpcTlsConfig;
 import org.apache.ratis.grpc.GrpcUtil;
 import org.apache.ratis.protocol.*;
 import org.apache.ratis.protocol.exceptions.AlreadyClosedException;
@@ -36,6 +35,7 @@ import 
org.apache.ratis.proto.RaftProtos.SetConfigurationRequestProto;
 import org.apache.ratis.proto.RaftProtos.TransferLeadershipRequestProto;
 import org.apache.ratis.proto.RaftProtos.SnapshotManagementRequestProto;
 import org.apache.ratis.proto.RaftProtos.LeaderElectionManagementRequestProto;
+import org.apache.ratis.thirdparty.io.netty.handler.ssl.SslContext;
 import org.apache.ratis.util.IOUtils;
 import org.apache.ratis.util.JavaUtils;
 import org.apache.ratis.util.PeerProxyMap;
@@ -54,9 +54,9 @@ public class GrpcClientRpc extends 
RaftClientRpcWithProxy<GrpcClientProtocolClie
   private final int maxMessageSize;
 
   public GrpcClientRpc(ClientId clientId, RaftProperties properties,
-      GrpcTlsConfig adminTlsConfig, GrpcTlsConfig clientTlsConfig) {
+      SslContext adminSslContext, SslContext clientSslContext) {
     super(new PeerProxyMap<>(clientId.toString(),
-        p -> new GrpcClientProtocolClient(clientId, p, properties, 
adminTlsConfig, clientTlsConfig)));
+        p -> new GrpcClientProtocolClient(clientId, p, properties, 
adminSslContext, clientSslContext)));
     this.clientId = clientId;
     this.maxMessageSize = GrpcConfigKeys.messageSizeMax(properties, 
LOG::debug).getSizeInt();
   }
diff --git 
a/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcServerProtocolClient.java
 
b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcServerProtocolClient.java
index 2e936bb0b..a0a17dc9f 100644
--- 
a/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcServerProtocolClient.java
+++ 
b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcServerProtocolClient.java
@@ -1,4 +1,4 @@
-/**
+/*
  * Licensed to the Apache Software Foundation (ASF) under one
  * or more contributor license agreements.  See the NOTICE file
  * distributed with this work for additional information
@@ -17,13 +17,11 @@
  */
 package org.apache.ratis.grpc.server;
 
-import org.apache.ratis.grpc.GrpcTlsConfig;
 import org.apache.ratis.grpc.GrpcUtil;
 import org.apache.ratis.grpc.util.StreamObserverWithTimeout;
 import org.apache.ratis.protocol.RaftPeerId;
 import org.apache.ratis.server.util.ServerStringUtils;
 import org.apache.ratis.thirdparty.io.grpc.ManagedChannel;
-import org.apache.ratis.thirdparty.io.grpc.netty.GrpcSslContexts;
 import org.apache.ratis.thirdparty.io.grpc.netty.NegotiationType;
 import org.apache.ratis.thirdparty.io.grpc.netty.NettyChannelBuilder;
 import org.apache.ratis.thirdparty.io.grpc.stub.CallStreamObserver;
@@ -33,7 +31,7 @@ import 
org.apache.ratis.proto.grpc.RaftServerProtocolServiceGrpc;
 import 
org.apache.ratis.proto.grpc.RaftServerProtocolServiceGrpc.RaftServerProtocolServiceBlockingStub;
 import 
org.apache.ratis.proto.grpc.RaftServerProtocolServiceGrpc.RaftServerProtocolServiceStub;
 import org.apache.ratis.protocol.RaftPeer;
-import org.apache.ratis.thirdparty.io.netty.handler.ssl.SslContextBuilder;
+import org.apache.ratis.thirdparty.io.netty.handler.ssl.SslContext;
 import org.apache.ratis.util.TimeDuration;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -44,7 +42,7 @@ import java.io.Closeable;
  * This is a RaftClient implementation that supports streaming data to the raft
  * ring. The stream implementation utilizes gRPC.
  */
-public class GrpcServerProtocolClient implements Closeable {
+class GrpcServerProtocolClient implements Closeable {
   // Common channel
   private final ManagedChannel channel;
   private final GrpcStubPool<RaftServerProtocolServiceStub> pool;
@@ -60,42 +58,30 @@ public class GrpcServerProtocolClient implements Closeable {
   //visible for using in log / error messages AND to use in instrumented tests
   private final RaftPeerId raftPeerId;
 
-  public GrpcServerProtocolClient(RaftPeer target, int connections, int 
flowControlWindow,
-      TimeDuration requestTimeout, GrpcTlsConfig tlsConfig, boolean 
separateHBChannel) {
+  GrpcServerProtocolClient(RaftPeer target, int connections, int 
flowControlWindow,
+      TimeDuration requestTimeout, SslContext sslContext, boolean 
separateHBChannel) {
     raftPeerId = target.getId();
     LOG.info("Build channel for {}", target);
     useSeparateHBChannel = separateHBChannel;
-    channel = buildChannel(target, flowControlWindow, tlsConfig);
+    channel = buildChannel(target, flowControlWindow, sslContext);
     blockingStub = RaftServerProtocolServiceGrpc.newBlockingStub(channel);
     asyncStub = RaftServerProtocolServiceGrpc.newStub(channel);
     if (useSeparateHBChannel) {
-      hbChannel = buildChannel(target, flowControlWindow, tlsConfig);
+      hbChannel = buildChannel(target, flowControlWindow, sslContext);
       hbAsyncStub = RaftServerProtocolServiceGrpc.newStub(hbChannel);
     }
     requestTimeoutDuration = requestTimeout;
-    this.pool = new GrpcStubPool<RaftServerProtocolServiceStub>(target, 
connections,
-            ch -> RaftServerProtocolServiceGrpc.newStub(ch), tlsConfig);
+    this.pool = new GrpcStubPool<>(target, connections, 
RaftServerProtocolServiceGrpc::newStub, sslContext);
   }
 
-  private ManagedChannel buildChannel(RaftPeer target, int flowControlWindow,
-      GrpcTlsConfig tlsConfig) {
+  private ManagedChannel buildChannel(RaftPeer target, int flowControlWindow, 
SslContext sslContext) {
     NettyChannelBuilder channelBuilder =
         NettyChannelBuilder.forTarget(target.getAddress());
     // ignore any http proxy for grpc
     channelBuilder.proxyDetector(uri -> null);
 
-    if (tlsConfig!= null) {
-      SslContextBuilder sslContextBuilder = GrpcSslContexts.forClient();
-      GrpcUtil.setTrustManager(sslContextBuilder, tlsConfig.getTrustManager());
-      if (tlsConfig.getMtlsEnabled()) {
-        GrpcUtil.setKeyManager(sslContextBuilder, tlsConfig.getKeyManager());
-      }
-      try {
-        
channelBuilder.useTransportSecurity().sslContext(sslContextBuilder.build());
-      } catch (Exception ex) {
-        throw new IllegalArgumentException("Failed to build SslContext, 
peerId=" + raftPeerId
-            + ", tlsConfig=" + tlsConfig, ex);
-      }
+    if (sslContext != null) {
+      channelBuilder.useTransportSecurity().sslContext(sslContext);
     } else {
       channelBuilder.negotiationType(NegotiationType.PLAINTEXT);
     }
diff --git 
a/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcServicesImpl.java 
b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcServicesImpl.java
index b686be0a2..b1af0960d 100644
--- 
a/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcServicesImpl.java
+++ 
b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcServicesImpl.java
@@ -19,8 +19,6 @@ package org.apache.ratis.grpc.server;
 
 import org.apache.ratis.conf.RaftProperties;
 import org.apache.ratis.grpc.GrpcConfigKeys;
-import org.apache.ratis.grpc.GrpcTlsConfig;
-import org.apache.ratis.grpc.GrpcUtil;
 import org.apache.ratis.grpc.metrics.MessageMetrics;
 import org.apache.ratis.grpc.metrics.intercept.server.MetricServerInterceptor;
 import org.apache.ratis.protocol.AdminAsynchronousProtocol;
@@ -34,13 +32,11 @@ import org.apache.ratis.server.RaftServerRpcWithProxy;
 import org.apache.ratis.server.protocol.RaftServerAsynchronousProtocol;
 import org.apache.ratis.thirdparty.io.grpc.ServerInterceptor;
 import org.apache.ratis.thirdparty.io.grpc.ServerInterceptors;
-import org.apache.ratis.thirdparty.io.grpc.netty.GrpcSslContexts;
 import org.apache.ratis.thirdparty.io.grpc.netty.NettyServerBuilder;
 import org.apache.ratis.thirdparty.io.grpc.Server;
 import org.apache.ratis.thirdparty.io.grpc.stub.StreamObserver;
 import org.apache.ratis.thirdparty.io.netty.channel.ChannelOption;
-import org.apache.ratis.thirdparty.io.netty.handler.ssl.ClientAuth;
-import org.apache.ratis.thirdparty.io.netty.handler.ssl.SslContextBuilder;
+import org.apache.ratis.thirdparty.io.netty.handler.ssl.SslContext;
 
 import org.apache.ratis.proto.RaftProtos.*;
 import org.apache.ratis.util.*;
@@ -56,8 +52,6 @@ import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ExecutorService;
 import java.util.function.Supplier;
 
-import static 
org.apache.ratis.thirdparty.io.netty.handler.ssl.SslProvider.OPENSSL;
-
 /** A grpc implementation of {@link org.apache.ratis.server.RaftServerRpc}. */
 public final class GrpcServicesImpl
     extends RaftServerRpcWithProxy<GrpcServerProtocolClient, 
PeerProxyMap<GrpcServerProtocolClient>>
@@ -106,13 +100,14 @@ public final class GrpcServicesImpl
 
     private String adminHost;
     private int adminPort;
-    private GrpcTlsConfig adminTlsConfig;
+    private SslContext adminSslContext;
     private String clientHost;
     private int clientPort;
-    private GrpcTlsConfig clientTlsConfig;
+    private SslContext clientSslContext;
     private String serverHost;
     private int serverPort;
-    private GrpcTlsConfig serverTlsConfig;
+    private SslContext serverSslContextForServer;
+    private SslContext serverSslContextForClient;
     private int serverStubPoolSize;
 
     private SizeInBytes messageSizeMax;
@@ -158,7 +153,7 @@ public final class GrpcServicesImpl
 
     private GrpcServerProtocolClient newGrpcServerProtocolClient(RaftPeer 
target) {
       return new GrpcServerProtocolClient(target, serverStubPoolSize, 
flowControlWindow.getSizeInt(),
-          requestTimeoutDuration, serverTlsConfig, separateHeartbeatChannel);
+          requestTimeoutDuration, serverSslContextForClient, 
separateHeartbeatChannel);
     }
 
     private ExecutorService newExecutor() {
@@ -188,18 +183,18 @@ public final class GrpcServicesImpl
     }
 
     private NettyServerBuilder newNettyServerBuilderForServer() {
-      return newNettyServerBuilder(serverHost, serverPort, serverTlsConfig);
+      return newNettyServerBuilder(serverHost, serverPort, 
serverSslContextForServer);
     }
 
     private NettyServerBuilder newNettyServerBuilderForAdmin() {
-      return newNettyServerBuilder(adminHost, adminPort, adminTlsConfig);
+      return newNettyServerBuilder(adminHost, adminPort, adminSslContext);
     }
 
     private NettyServerBuilder newNettyServerBuilderForClient() {
-      return newNettyServerBuilder(clientHost, clientPort, clientTlsConfig);
+      return newNettyServerBuilder(clientHost, clientPort, clientSslContext);
     }
 
-    private NettyServerBuilder newNettyServerBuilder(String hostname, int 
port, GrpcTlsConfig tlsConfig) {
+    private NettyServerBuilder newNettyServerBuilder(String hostname, int 
port, SslContext sslContext) {
       final InetSocketAddress address = hostname == null || hostname.isEmpty() 
?
           new InetSocketAddress(port) : new InetSocketAddress(hostname, port);
       final NettyServerBuilder nettyServerBuilder = 
NettyServerBuilder.forAddress(address)
@@ -207,19 +202,9 @@ public final class GrpcServicesImpl
           .maxInboundMessageSize(messageSizeMax.getSizeInt())
           .flowControlWindow(flowControlWindow.getSizeInt());
 
-      if (tlsConfig != null) {
+      if (sslContext != null) {
         LOG.info("Setting TLS for {}", address);
-        SslContextBuilder sslContextBuilder = 
GrpcUtil.initSslContextBuilderForServer(tlsConfig.getKeyManager());
-        if (tlsConfig.getMtlsEnabled()) {
-          sslContextBuilder.clientAuth(ClientAuth.REQUIRE);
-          GrpcUtil.setTrustManager(sslContextBuilder, 
tlsConfig.getTrustManager());
-        }
-        sslContextBuilder = GrpcSslContexts.configure(sslContextBuilder, 
OPENSSL);
-        try {
-          nettyServerBuilder.sslContext(sslContextBuilder.build());
-        } catch (Exception ex) {
-          throw new IllegalArgumentException("Failed to build SslContext, 
tlsConfig=" + tlsConfig, ex);
-        }
+        nettyServerBuilder.sslContext(sslContext);
       }
       return nettyServerBuilder;
     }
@@ -253,18 +238,23 @@ public final class GrpcServicesImpl
       return new GrpcServicesImpl(this);
     }
 
-    public Builder setAdminTlsConfig(GrpcTlsConfig config) {
-      this.adminTlsConfig = config;
+    public Builder setAdminSslContext(SslContext adminSslContext) {
+      this.adminSslContext = adminSslContext;
+      return this;
+    }
+
+    public Builder setClientSslContext(SslContext clientSslContext) {
+      this.clientSslContext = clientSslContext;
       return this;
     }
 
-    public Builder setClientTlsConfig(GrpcTlsConfig config) {
-      this.clientTlsConfig = config;
+    public Builder setServerSslContextForServer(SslContext 
serverSslContextForServer) {
+      this.serverSslContextForServer = serverSslContextForServer;
       return this;
     }
 
-    public Builder setServerTlsConfig(GrpcTlsConfig config) {
-      this.serverTlsConfig = config;
+    public Builder setServerSslContextForClient(SslContext 
serverSslContextForClient) {
+      this.serverSslContextForClient = serverSslContextForClient;
       return this;
     }
   }
diff --git 
a/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcStubPool.java 
b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcStubPool.java
index fcfb0f1b8..c949707a4 100644
--- a/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcStubPool.java
+++ b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcStubPool.java
@@ -17,11 +17,8 @@
  */
 package org.apache.ratis.grpc.server;
 
-import org.apache.ratis.grpc.GrpcTlsConfig;
-import org.apache.ratis.grpc.GrpcUtil;
 import org.apache.ratis.protocol.RaftPeer;
 import org.apache.ratis.thirdparty.io.grpc.ManagedChannel;
-import org.apache.ratis.thirdparty.io.grpc.netty.GrpcSslContexts;
 import org.apache.ratis.thirdparty.io.grpc.netty.NegotiationType;
 import org.apache.ratis.thirdparty.io.grpc.netty.NettyChannelBuilder;
 import org.apache.ratis.thirdparty.io.grpc.stub.AbstractStub;
@@ -29,7 +26,7 @@ import org.apache.ratis.thirdparty.io.netty.channel.ChannelOption;
 import org.apache.ratis.thirdparty.io.netty.channel.WriteBufferWaterMark;
 import org.apache.ratis.thirdparty.io.netty.channel.nio.NioEventLoopGroup;
 import 
org.apache.ratis.thirdparty.io.netty.channel.socket.nio.NioSocketChannel;
-import org.apache.ratis.thirdparty.io.netty.handler.ssl.SslContextBuilder;
+import org.apache.ratis.thirdparty.io.netty.handler.ssl.SslContext;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -39,7 +36,6 @@ import java.util.List;
 import java.util.concurrent.Semaphore;
 import java.util.concurrent.ThreadLocalRandom;
 import java.util.concurrent.TimeUnit;
-import java.util.concurrent.atomic.AtomicInteger;
 import java.util.function.Function;
 
 final class GrpcStubPool<S extends AbstractStub<S>> {
@@ -66,16 +62,14 @@ final class GrpcStubPool<S extends AbstractStub<S>> {
   }
 
   private final List<PooledStub<S>> pool;
-  private final AtomicInteger rr = new AtomicInteger();
   private final NioEventLoopGroup elg;
   private final int size;
 
-  GrpcStubPool(RaftPeer target, int n, Function<ManagedChannel, S> 
stubFactory, GrpcTlsConfig tlsConfig) {
-    this(target, n, stubFactory, tlsConfig, Math.max(2, 
Runtime.getRuntime().availableProcessors() / 2), 16);
+  GrpcStubPool(RaftPeer target, int n, Function<ManagedChannel, S> 
stubFactory, SslContext sslContext) {
+    this(target, n, stubFactory, sslContext, Math.max(2, 
Runtime.getRuntime().availableProcessors() / 2), 16);
   }
 
-  GrpcStubPool(RaftPeer target, int n,
-               Function<ManagedChannel, S> stubFactory, GrpcTlsConfig tlsConf,
+  GrpcStubPool(RaftPeer target, int n, Function<ManagedChannel, S> 
stubFactory, SslContext sslContext,
                int elgThreads, int maxInflightPerConn) {
     this.elg = new NioEventLoopGroup(elgThreads);
     ArrayList<PooledStub<S>> tmp = new ArrayList<>(n);
@@ -87,18 +81,9 @@ final class GrpcStubPool<S extends AbstractStub<S>> {
           .keepAliveWithoutCalls(true)
           .idleTimeout(24, TimeUnit.HOURS)
           .withOption(ChannelOption.WRITE_BUFFER_WATER_MARK, new 
WriteBufferWaterMark(64 << 10, 128 << 10));
-      if (tlsConf != null) {
+      if (sslContext != null) {
         LOG.debug("Setting TLS for {}", target.getAddress());
-        SslContextBuilder sslContextBuilder = GrpcSslContexts.forClient();
-        GrpcUtil.setTrustManager(sslContextBuilder, tlsConf.getTrustManager());
-        if (tlsConf.getMtlsEnabled()) {
-          GrpcUtil.setKeyManager(sslContextBuilder, tlsConf.getKeyManager());
-        }
-        try {
-          
channelBuilder.useTransportSecurity().sslContext(sslContextBuilder.build());
-        } catch (Exception ex) {
-          throw new RuntimeException(ex);
-        }
+        channelBuilder.useTransportSecurity().sslContext(sslContext);
       } else {
         channelBuilder.negotiationType(NegotiationType.PLAINTEXT);
       }
@@ -124,7 +109,7 @@ final class GrpcStubPool<S extends AbstractStub<S>> {
   }
 
   public void close() {
-    for (PooledStub p : pool) {
+    for (PooledStub<S> p : pool) {
       p.ch.shutdown();
     }
     elg.shutdownGracefully();


Reply via email to