This is an automated email from the ASF dual-hosted git repository.

szetszwo pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/ratis.git


The following commit(s) were added to refs/heads/master by this push:
     new e96ed1a33 RATIS-2168. Support custom gRPC services. (#1169)
e96ed1a33 is described below

commit e96ed1a33840385446f4e647864a169467da5ab7
Author: Tsz-Wo Nicholas Sze <[email protected]>
AuthorDate: Wed Oct 16 18:06:59 2024 -0700

    RATIS-2168. Support custom gRPC services. (#1169)
---
 ratis-grpc/pom.xml                                 |   4 +
 .../java/org/apache/ratis/grpc/GrpcConfigKeys.java |  29 ++-
 .../java/org/apache/ratis/grpc/GrpcFactory.java    |  19 +-
 .../apache/ratis/grpc/server/GrpcLogAppender.java  |   6 +-
 .../org/apache/ratis/grpc/server/GrpcServices.java |  56 ++++++
 .../{GrpcService.java => GrpcServicesImpl.java}    |  45 +++--
 .../apache/ratis/grpc/MiniRaftClusterWithGrpc.java |   6 +-
 .../apache/ratis/grpc/TestCustomGrpcServices.java  | 205 +++++++++++++++++++++
 .../grpc/{ => server}/TestGrpcMessageMetrics.java  |   8 +-
 9 files changed, 339 insertions(+), 39 deletions(-)

diff --git a/ratis-grpc/pom.xml b/ratis-grpc/pom.xml
index 2e542215c..fc6797831 100644
--- a/ratis-grpc/pom.xml
+++ b/ratis-grpc/pom.xml
@@ -53,6 +53,10 @@
       <scope>test</scope>
       <type>test-jar</type>
     </dependency>
+    <dependency>
+      <groupId>org.apache.ratis</groupId>
+      <artifactId>ratis-server-api</artifactId>
+    </dependency>
     <dependency>
       <artifactId>ratis-server</artifactId>
       <groupId>org.apache.ratis</groupId>
diff --git a/ratis-grpc/src/main/java/org/apache/ratis/grpc/GrpcConfigKeys.java b/ratis-grpc/src/main/java/org/apache/ratis/grpc/GrpcConfigKeys.java
index c14d844ee..2fcb9b6b0 100644
--- a/ratis-grpc/src/main/java/org/apache/ratis/grpc/GrpcConfigKeys.java
+++ b/ratis-grpc/src/main/java/org/apache/ratis/grpc/GrpcConfigKeys.java
@@ -19,6 +19,7 @@ package org.apache.ratis.grpc;
 
 import org.apache.ratis.conf.Parameters;
 import org.apache.ratis.conf.RaftProperties;
+import org.apache.ratis.grpc.server.GrpcServices;
 import org.apache.ratis.server.RaftServerConfigKeys;
 import org.apache.ratis.util.SizeInBytes;
 import org.apache.ratis.util.TimeDuration;
@@ -230,15 +231,6 @@ public interface GrpcConfigKeys {
       setInt(properties::setInt, ASYNC_REQUEST_THREAD_POOL_SIZE_KEY, port);
     }
 
-    String TLS_CONF_PARAMETER = PREFIX + ".tls.conf";
-    Class<GrpcTlsConfig> TLS_CONF_CLASS = TLS.CONF_CLASS;
-    static GrpcTlsConfig tlsConf(Parameters parameters) {
-      return parameters != null ? parameters.get(TLS_CONF_PARAMETER, TLS_CONF_CLASS): null;
-    }
-    static void setTlsConf(Parameters parameters, GrpcTlsConfig conf) {
-      parameters.put(TLS_CONF_PARAMETER, conf, TLS_CONF_CLASS);
-    }
-
     String LEADER_OUTSTANDING_APPENDS_MAX_KEY = PREFIX + ".leader.outstanding.appends.max";
     int LEADER_OUTSTANDING_APPENDS_MAX_DEFAULT = 8;
     static int leaderOutstandingAppendsMax(RaftProperties properties) {
@@ -301,6 +293,25 @@ public interface GrpcConfigKeys {
     static void setZeroCopyEnabled(RaftProperties properties, boolean enabled) {
       setBoolean(properties::setBoolean, ZERO_COPY_ENABLED_KEY, enabled);
     }
+
+    String SERVICES_CUSTOMIZER_PARAMETER = PREFIX + ".services.customizer";
+    Class<GrpcServices.Customizer> SERVICES_CUSTOMIZER_CLASS = GrpcServices.Customizer.class;
+    static GrpcServices.Customizer servicesCustomizer(Parameters parameters) {
+      return parameters == null ? null
+          : parameters.get(SERVICES_CUSTOMIZER_PARAMETER, SERVICES_CUSTOMIZER_CLASS);
+    }
+    static void setServicesCustomizer(Parameters parameters, GrpcServices.Customizer customizer) {
+      parameters.put(SERVICES_CUSTOMIZER_PARAMETER, customizer, SERVICES_CUSTOMIZER_CLASS);
+    }
+
+    String TLS_CONF_PARAMETER = PREFIX + ".tls.conf";
+    Class<GrpcTlsConfig> TLS_CONF_CLASS = TLS.CONF_CLASS;
+    static GrpcTlsConfig tlsConf(Parameters parameters) {
+      return parameters != null ? parameters.get(TLS_CONF_PARAMETER, TLS_CONF_CLASS): null;
+    }
+    static void setTlsConf(Parameters parameters, GrpcTlsConfig conf) {
+      parameters.put(TLS_CONF_PARAMETER, conf, TLS_CONF_CLASS);
+    }
   }
 
   String MESSAGE_SIZE_MAX_KEY = PREFIX + ".message.size.max";
diff --git a/ratis-grpc/src/main/java/org/apache/ratis/grpc/GrpcFactory.java b/ratis-grpc/src/main/java/org/apache/ratis/grpc/GrpcFactory.java
index 75eb34a2d..331d1a858 100644
--- a/ratis-grpc/src/main/java/org/apache/ratis/grpc/GrpcFactory.java
+++ b/ratis-grpc/src/main/java/org/apache/ratis/grpc/GrpcFactory.java
@@ -22,7 +22,8 @@ import org.apache.ratis.conf.Parameters;
 import org.apache.ratis.conf.RaftProperties;
 import org.apache.ratis.grpc.client.GrpcClientRpc;
 import org.apache.ratis.grpc.server.GrpcLogAppender;
-import org.apache.ratis.grpc.server.GrpcService;
+import org.apache.ratis.grpc.server.GrpcServices;
+import org.apache.ratis.grpc.server.GrpcServicesImpl;
 import org.apache.ratis.protocol.ClientId;
 import org.apache.ratis.rpc.SupportedRpcType;
 import org.apache.ratis.server.RaftServer;
@@ -64,6 +65,8 @@ public class GrpcFactory implements ServerFactory, ClientFactory {
     return value;
   }
 
+  private final GrpcServices.Customizer servicesCustomizer;
+
   private final GrpcTlsConfig tlsConfig;
   private final GrpcTlsConfig adminTlsConfig;
   private final GrpcTlsConfig clientTlsConfig;
@@ -76,7 +79,7 @@ public class GrpcFactory implements ServerFactory, ClientFactory {
   }
 
   public GrpcFactory(Parameters parameters) {
-    this(
+    this(GrpcConfigKeys.Server.servicesCustomizer(parameters),
         GrpcConfigKeys.TLS.conf(parameters),
         GrpcConfigKeys.Admin.tlsConf(parameters),
         GrpcConfigKeys.Client.tlsConf(parameters),
@@ -85,11 +88,14 @@ public class GrpcFactory implements ServerFactory, ClientFactory {
   }
 
   public GrpcFactory(GrpcTlsConfig tlsConfig) {
-    this(tlsConfig, null, null, null);
+    this(null, tlsConfig, null, null, null);
   }
 
-  private GrpcFactory(GrpcTlsConfig tlsConfig, GrpcTlsConfig adminTlsConfig,
+  private GrpcFactory(GrpcServices.Customizer servicesCustomizer,
+      GrpcTlsConfig tlsConfig, GrpcTlsConfig adminTlsConfig,
       GrpcTlsConfig clientTlsConfig, GrpcTlsConfig serverTlsConfig) {
+    this.servicesCustomizer = servicesCustomizer;
+
     this.tlsConfig = tlsConfig;
     this.adminTlsConfig = adminTlsConfig;
     this.clientTlsConfig = clientTlsConfig;
@@ -123,10 +129,11 @@ public class GrpcFactory implements ServerFactory, ClientFactory {
   }
 
   @Override
-  public GrpcService newRaftServerRpc(RaftServer server) {
+  public GrpcServices newRaftServerRpc(RaftServer server) {
     checkPooledByteBufAllocatorUseCacheForAllThreads(LOG::info);
-    return GrpcService.newBuilder()
+    return GrpcServicesImpl.newBuilder()
         .setServer(server)
+        .setCustomizer(servicesCustomizer)
         .setAdminTlsConfig(getAdminTlsConfig())
         .setServerTlsConfig(getServerTlsConfig())
         .setClientTlsConfig(getClientTlsConfig())
diff --git a/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcLogAppender.java b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcLogAppender.java
index 18d4c62c6..45bc4c888 100644
--- a/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcLogAppender.java
+++ b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcLogAppender.java
@@ -192,8 +192,8 @@ public class GrpcLogAppender extends LogAppenderBase {
   }
 
   @Override
-  public GrpcService getServerRpc() {
-    return (GrpcService)super.getServerRpc();
+  public GrpcServicesImpl getServerRpc() {
+    return (GrpcServicesImpl)super.getServerRpc();
   }
 
   private GrpcServerProtocolClient getClient() throws IOException {
@@ -428,7 +428,7 @@ public class GrpcLogAppender extends LogAppenderBase {
 
   private void sendRequest(AppendEntriesRequest request,
       AppendEntriesRequestProto proto) throws InterruptedIOException {
-    CodeInjectionForTesting.execute(GrpcService.GRPC_SEND_SERVER_REQUEST,
+    CodeInjectionForTesting.execute(GrpcServicesImpl.GRPC_SEND_SERVER_REQUEST,
         getServer().getId(), null, proto);
     resetHeartbeatTrigger();
 
diff --git a/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcServices.java b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcServices.java
new file mode 100644
index 000000000..663fd6d74
--- /dev/null
+++ b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcServices.java
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.ratis.grpc.server;
+
+import org.apache.ratis.server.RaftServerRpc;
+import org.apache.ratis.thirdparty.io.grpc.netty.NettyServerBuilder;
+
+import java.util.EnumSet;
+
+/** The gRPC services extending {@link RaftServerRpc}. */
+public interface GrpcServices extends RaftServerRpc {
+  /** The type of the services. */
+  enum Type {ADMIN, CLIENT, SERVER}
+
+  /**
+   * To customize the services.
+   * For example, add a custom service.
+   */
+  interface Customizer {
+    /** The default NOOP {@link Customizer}. */
+    class Default implements Customizer {
+      private static final Default INSTANCE = new Default();
+
+      @Override
+      public NettyServerBuilder customize(NettyServerBuilder builder, EnumSet<GrpcServices.Type> types) {
+        return builder;
+      }
+    }
+
+    static Customizer getDefaultInstance() {
+      return Default.INSTANCE;
+    }
+
+    /**
+     * Customize the given builder for the given types.
+     *
+     * @return the customized builder.
+     */
+    NettyServerBuilder customize(NettyServerBuilder builder, EnumSet<GrpcServices.Type> types);
+  }
+}
diff --git a/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcService.java b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcServicesImpl.java
similarity index 92%
rename from ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcService.java
rename to ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcServicesImpl.java
index 510dfcaa2..d6f6a0c86 100644
--- a/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcService.java
+++ b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcServicesImpl.java
@@ -21,6 +21,7 @@ import org.apache.ratis.conf.RaftProperties;
 import org.apache.ratis.grpc.GrpcConfigKeys;
 import org.apache.ratis.grpc.GrpcTlsConfig;
 import org.apache.ratis.grpc.GrpcUtil;
+import org.apache.ratis.grpc.metrics.MessageMetrics;
 import org.apache.ratis.grpc.metrics.ZeroCopyMetrics;
 import org.apache.ratis.grpc.metrics.intercept.server.MetricServerInterceptor;
 import org.apache.ratis.protocol.AdminAsynchronousProtocol;
@@ -51,6 +52,7 @@ import org.slf4j.LoggerFactory;
 
 import java.io.IOException;
 import java.net.InetSocketAddress;
+import java.util.EnumSet;
 import java.util.HashMap;
 import java.util.Map;
 import java.util.concurrent.CompletableFuture;
@@ -60,11 +62,12 @@ import java.util.function.Supplier;
 import static org.apache.ratis.thirdparty.io.netty.handler.ssl.SslProvider.OPENSSL;
 
 /** A grpc implementation of {@link org.apache.ratis.server.RaftServerRpc}. */
-public final class GrpcService extends RaftServerRpcWithProxy<GrpcServerProtocolClient,
-    PeerProxyMap<GrpcServerProtocolClient>> {
-  static final Logger LOG = LoggerFactory.getLogger(GrpcService.class);
+public final class GrpcServicesImpl
+    extends RaftServerRpcWithProxy<GrpcServerProtocolClient, PeerProxyMap<GrpcServerProtocolClient>>
+    implements GrpcServices {
+  static final Logger LOG = LoggerFactory.getLogger(GrpcServicesImpl.class);
   public static final String GRPC_SEND_SERVER_REQUEST =
-      JavaUtils.getClassSimpleName(GrpcService.class) + ".sendRequest";
+      JavaUtils.getClassSimpleName(GrpcServicesImpl.class) + ".sendRequest";
 
   class AsyncService implements RaftServerAsynchronousProtocol {
 
@@ -102,6 +105,7 @@ public final class GrpcService extends RaftServerRpcWithProxy<GrpcServerProtocol
 
   public static final class Builder {
     private RaftServer server;
+    private Customizer customizer;
 
     private String adminHost;
     private int adminPort;
@@ -150,6 +154,11 @@ public final class GrpcService extends RaftServerRpcWithProxy<GrpcServerProtocol
       return this;
     }
 
+    public Builder setCustomizer(Customizer customizer) {
+      this.customizer = customizer != null? customizer : Customizer.getDefaultInstance();
+      return this;
+    }
+
     private GrpcServerProtocolClient newGrpcServerProtocolClient(RaftPeer target) {
       return new GrpcServerProtocolClient(target, flowControlWindow.getSizeInt(),
           requestTimeoutDuration, serverTlsConfig, separateHeartbeatChannel);
@@ -177,6 +186,10 @@ public final class GrpcService extends RaftServerRpcWithProxy<GrpcServerProtocol
           JavaUtils.getClassSimpleName(getClass()) + "_" + serverPort);
     }
 
+    Server buildServer(NettyServerBuilder builder, EnumSet<GrpcServices.Type> types) {
+      return customizer.customize(builder, types).build();
+    }
+
     private NettyServerBuilder newNettyServerBuilderForServer() {
       return newNettyServerBuilder(serverHost, serverPort, serverTlsConfig);
     }
@@ -223,21 +236,24 @@ public final class GrpcService extends RaftServerRpcWithProxy<GrpcServerProtocol
     }
 
     Server newServer(GrpcClientProtocolService client, ZeroCopyMetrics zeroCopyMetrics, ServerInterceptor interceptor) {
+      final EnumSet<GrpcServices.Type> types = EnumSet.of(GrpcServices.Type.SERVER);
       final NettyServerBuilder serverBuilder = newNettyServerBuilderForServer();
       final ServerServiceDefinition service = newGrpcServerProtocolService(zeroCopyMetrics).bindServiceWithZeroCopy();
       serverBuilder.addService(ServerInterceptors.intercept(service, interceptor));
 
       if (!separateAdminServer()) {
+        types.add(GrpcServices.Type.ADMIN);
         addAdminService(serverBuilder, server, interceptor);
       }
       if (!separateClientServer()) {
+        types.add(GrpcServices.Type.CLIENT);
         addClientService(serverBuilder, client, interceptor);
       }
-      return serverBuilder.build();
+      return buildServer(serverBuilder, types);
     }
 
-    public GrpcService build() {
-      return new GrpcService(this);
+    public GrpcServicesImpl build() {
+      return new GrpcServicesImpl(this);
     }
 
     public Builder setAdminTlsConfig(GrpcTlsConfig config) {
@@ -273,11 +289,7 @@ public final class GrpcService extends RaftServerRpcWithProxy<GrpcServerProtocol
   private final MetricServerInterceptor serverInterceptor;
   private final ZeroCopyMetrics zeroCopyMetrics = new ZeroCopyMetrics();
 
-  public MetricServerInterceptor getServerInterceptor() {
-    return serverInterceptor;
-  }
-
-  private GrpcService(Builder b) {
+  private GrpcServicesImpl(Builder b) {
     super(b.server::getId, id -> new PeerProxyMap<>(id.toString(), b::newGrpcServerProtocolClient));
 
     this.executor = b.newExecutor();
@@ -291,7 +303,7 @@ public final class GrpcService extends RaftServerRpcWithProxy<GrpcServerProtocol
     if (b.separateAdminServer()) {
       final NettyServerBuilder builder = b.newNettyServerBuilderForAdmin();
       addAdminService(builder, b.server, serverInterceptor);
-      final Server adminServer = builder.build();
+      final Server adminServer = b.buildServer(builder, EnumSet.of(GrpcServices.Type.ADMIN));
       servers.put(GrpcAdminProtocolService.class.getName(), adminServer);
       adminServerAddressSupplier = newAddressSupplier(b.adminPort, adminServer);
     } else {
@@ -301,7 +313,7 @@ public final class GrpcService extends RaftServerRpcWithProxy<GrpcServerProtocol
     if (b.separateClientServer()) {
       final NettyServerBuilder builder = b.newNettyServerBuilderForClient();
       addClientService(builder, clientProtocolService, serverInterceptor);
-      final Server clientServer = builder.build();
+      final Server clientServer = b.buildServer(builder, EnumSet.of(GrpcServices.Type.CLIENT));
       servers.put(GrpcClientProtocolService.class.getName(), clientServer);
       clientServerAddressSupplier = newAddressSupplier(b.clientPort, clientServer);
     } else {
@@ -419,6 +431,11 @@ public final class GrpcService extends RaftServerRpcWithProxy<GrpcServerProtocol
     return getProxies().getProxy(target).startLeaderElection(request);
   }
 
+  @VisibleForTesting
+  MessageMetrics getMessageMetrics() {
+    return serverInterceptor.getMetrics();
+  }
+
   @VisibleForTesting
   public ZeroCopyMetrics getZeroCopyMetrics() {
     return zeroCopyMetrics;
diff --git a/ratis-grpc/src/test/java/org/apache/ratis/grpc/MiniRaftClusterWithGrpc.java b/ratis-grpc/src/test/java/org/apache/ratis/grpc/MiniRaftClusterWithGrpc.java
index fe12e29f1..1519298f3 100644
--- a/ratis-grpc/src/test/java/org/apache/ratis/grpc/MiniRaftClusterWithGrpc.java
+++ b/ratis-grpc/src/test/java/org/apache/ratis/grpc/MiniRaftClusterWithGrpc.java
@@ -22,7 +22,7 @@ import org.apache.ratis.RaftTestUtil;
 import org.apache.ratis.conf.Parameters;
 import org.apache.ratis.conf.RaftProperties;
 import org.apache.ratis.grpc.metrics.ZeroCopyMetrics;
-import org.apache.ratis.grpc.server.GrpcService;
+import org.apache.ratis.grpc.server.GrpcServicesImpl;
 import org.apache.ratis.protocol.RaftGroup;
 import org.apache.ratis.protocol.RaftPeer;
 import org.apache.ratis.protocol.RaftPeerId;
@@ -63,7 +63,7 @@ public class MiniRaftClusterWithGrpc extends MiniRaftCluster.RpcBase {
   }
 
   public static final DelayLocalExecutionInjection SEND_SERVER_REQUEST_INJECTION =
-      new DelayLocalExecutionInjection(GrpcService.GRPC_SEND_SERVER_REQUEST);
+      new DelayLocalExecutionInjection(GrpcServicesImpl.GRPC_SEND_SERVER_REQUEST);
 
   public MiniRaftClusterWithGrpc(String[] ids, RaftProperties properties, Parameters parameters) {
     this(ids, new String[0], properties, parameters);
@@ -102,7 +102,7 @@ public class MiniRaftClusterWithGrpc extends MiniRaftCluster.RpcBase {
     getServers().forEach(server -> server.getGroupIds().forEach(id -> {
       LOG.info("Checking {}-{}", server.getId(), id);
       RaftServer.Division division = RaftServerTestUtil.getDivision(server, id);
-      GrpcService service = (GrpcService) RaftServerTestUtil.getServerRpc(division);
+      final GrpcServicesImpl service = (GrpcServicesImpl) RaftServerTestUtil.getServerRpc(division);
       ZeroCopyMetrics zeroCopyMetrics = service.getZeroCopyMetrics();
       Assert.assertEquals(0, zeroCopyMetrics.nonZeroCopyMessages());
       Assert.assertEquals("Zero copy messages are not released, please check logs to find leaks. ",
diff --git a/ratis-test/src/test/java/org/apache/ratis/grpc/TestCustomGrpcServices.java b/ratis-test/src/test/java/org/apache/ratis/grpc/TestCustomGrpcServices.java
new file mode 100644
index 000000000..13c4a59fb
--- /dev/null
+++ b/ratis-test/src/test/java/org/apache/ratis/grpc/TestCustomGrpcServices.java
@@ -0,0 +1,205 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.ratis.grpc;
+
+import org.apache.ratis.BaseTest;
+import org.apache.ratis.RaftTestUtil;
+import org.apache.ratis.client.RaftClient;
+import org.apache.ratis.conf.Parameters;
+import org.apache.ratis.conf.RaftProperties;
+import org.apache.ratis.grpc.server.GrpcServices;
+import org.apache.ratis.protocol.RaftClientReply;
+import org.apache.ratis.server.RaftServerRpc;
+import org.apache.ratis.test.proto.GreeterGrpc;
+import org.apache.ratis.test.proto.HelloReply;
+import org.apache.ratis.test.proto.HelloRequest;
+import org.apache.ratis.thirdparty.io.grpc.ManagedChannel;
+import org.apache.ratis.thirdparty.io.grpc.ManagedChannelBuilder;
+import org.apache.ratis.thirdparty.io.grpc.netty.NettyServerBuilder;
+import org.apache.ratis.thirdparty.io.grpc.stub.StreamObserver;
+import org.apache.ratis.util.IOUtils;
+import org.apache.ratis.util.NetUtils;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+
+import java.io.Closeable;
+import java.io.IOException;
+import java.util.EnumSet;
+import java.util.Objects;
+import java.util.Queue;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ConcurrentLinkedQueue;
+import java.util.concurrent.TimeUnit;
+
+import static org.apache.ratis.RaftTestUtil.waitForLeader;
+
+public class TestCustomGrpcServices extends BaseTest {
+  /** Add two different greeter services for client and admin. */
+  class MyCustomizer implements GrpcServices.Customizer {
+    final GreeterImpl clientGreeter = new GreeterImpl("Hello");
+    final GreeterImpl adminGreeter = new GreeterImpl("Hi");
+
+    @Override
+    public NettyServerBuilder customize(NettyServerBuilder builder, EnumSet<GrpcServices.Type> types) {
+      if (types.contains(GrpcServices.Type.CLIENT)) {
+        return builder.addService(clientGreeter);
+      }
+      if (types.contains(GrpcServices.Type.ADMIN)) {
+        return builder.addService(adminGreeter);
+      }
+      return builder;
+    }
+  }
+
+  class GreeterImpl extends GreeterGrpc.GreeterImplBase {
+    private final String prefix;
+
+    GreeterImpl(String prefix) {
+      this.prefix = prefix;
+    }
+
+    String toReply(String request) {
+      return prefix + " " + request;
+    }
+
+    @Override
+    public StreamObserver<HelloRequest> hello(StreamObserver<HelloReply> responseObserver) {
+      return new StreamObserver<HelloRequest>() {
+        @Override
+        public void onNext(HelloRequest helloRequest) {
+          final String reply = toReply(helloRequest.getName());
+          responseObserver.onNext(HelloReply.newBuilder().setMessage(reply).build());
+        }
+
+        @Override
+        public void onError(Throwable throwable) {
+          LOG.error("onError", throwable);
+        }
+
+        @Override
+        public void onCompleted() {
+          responseObserver.onCompleted();
+        }
+      };
+    }
+  }
+
+  class GreeterClient implements Closeable {
+    private final ManagedChannel channel;
+    private final StreamObserver<HelloRequest> requestHandler;
+    private final Queue<CompletableFuture<String>> replies = new ConcurrentLinkedQueue<>();
+
+    GreeterClient(int port) {
+      this.channel = ManagedChannelBuilder.forAddress(NetUtils.LOCALHOST, port)
+          .usePlaintext()
+          .build();
+
+      final StreamObserver<HelloReply> responseHandler = new StreamObserver<HelloReply>() {
+        @Override
+        public void onNext(HelloReply helloReply) {
+          Objects.requireNonNull(replies.poll(), "queue is empty")
+              .complete(helloReply.getMessage());
+        }
+
+        @Override
+        public void onError(Throwable throwable) {
+          LOG.info("onError", throwable);
+          completeExceptionally(throwable);
+        }
+
+        @Override
+        public void onCompleted() {
+          LOG.info("onCompleted");
+          completeExceptionally(new IllegalStateException("onCompleted"));
+        }
+
+        void completeExceptionally(Throwable throwable) {
+          replies.forEach(f -> f.completeExceptionally(throwable));
+          replies.clear();
+        }
+      };
+      this.requestHandler = GreeterGrpc.newStub(channel).hello(responseHandler);
+    }
+
+    @Override
+    public void close() throws IOException {
+      try {
+        /* After the request handler is cancelled, no more life-cycle hooks are allowed,
+         * see {@link org.apache.ratis.thirdparty.io.grpc.ClientCall.Listener#cancel(String, Throwable)} */
+        // requestHandler.onCompleted();
+        channel.shutdown().awaitTermination(5, TimeUnit.SECONDS);
+      } catch (InterruptedException e) {
+        Thread.currentThread().interrupt();
+        throw IOUtils.toInterruptedIOException("Failed to close", e);
+      }
+    }
+
+    CompletableFuture<String> send(String name) {
+      LOG.info("send: {}", name);
+      final HelloRequest request = HelloRequest.newBuilder().setName(name).build();
+      final CompletableFuture<String> f = new CompletableFuture<>();
+      try {
+        requestHandler.onNext(request);
+        replies.offer(f);
+      } catch (IllegalStateException e) {
+        // already closed
+        f.completeExceptionally(e);
+      }
+      return f.whenComplete((r, e) -> LOG.info("reply: {}", r));
+    }
+  }
+
+  @Test
+  public void testCustomServices() throws Exception {
+    final String[] ids = {"s0"};
+    final RaftProperties properties = new RaftProperties();
+
+    final Parameters parameters = new Parameters();
+    final MyCustomizer customizer = new MyCustomizer();
+    GrpcConfigKeys.Server.setServicesCustomizer(parameters, customizer);
+
+    try(MiniRaftClusterWithGrpc cluster = new MiniRaftClusterWithGrpc(ids, properties, parameters)) {
+      cluster.start();
+      final RaftServerRpc server = waitForLeader(cluster).getRaftServer().getServerRpc();
+
+      // test Raft service
+      try (RaftClient client = cluster.createClient()) {
+        final RaftClientReply reply = client.io().send(new RaftTestUtil.SimpleMessage("abc"));
+        Assertions.assertTrue(reply.isSuccess());
+      }
+
+      // test custom client service
+      final int clientPort = server.getClientServerAddress().getPort();
+      try (GreeterClient client = new GreeterClient(clientPort)) {
+        sendAndAssertReply("world", client, customizer.clientGreeter);
+      }
+
+      // test custom admin service
+      final int adminPort = server.getAdminServerAddress().getPort();
+      try (GreeterClient admin = new GreeterClient(adminPort)) {
+        sendAndAssertReply("admin", admin, customizer.adminGreeter);
+      }
+    }
+  }
+
+  static void sendAndAssertReply(String name, GreeterClient client, GreeterImpl greeter) {
+    final String computed = client.send(name).join();
+    final String expected = greeter.toReply(name);
+    Assertions.assertEquals(expected, computed);
+  }
+}
diff --git a/ratis-test/src/test/java/org/apache/ratis/grpc/TestGrpcMessageMetrics.java b/ratis-test/src/test/java/org/apache/ratis/grpc/server/TestGrpcMessageMetrics.java
similarity index 91%
rename from ratis-test/src/test/java/org/apache/ratis/grpc/TestGrpcMessageMetrics.java
rename to ratis-test/src/test/java/org/apache/ratis/grpc/server/TestGrpcMessageMetrics.java
index 812c691e2..8094069cf 100644
--- a/ratis-test/src/test/java/org/apache/ratis/grpc/TestGrpcMessageMetrics.java
+++ b/ratis-test/src/test/java/org/apache/ratis/grpc/server/TestGrpcMessageMetrics.java
@@ -15,13 +15,13 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.ratis.grpc;
+package org.apache.ratis.grpc.server;
 
 import org.apache.ratis.BaseTest;
+import org.apache.ratis.grpc.MiniRaftClusterWithGrpc;
 import org.apache.ratis.server.impl.MiniRaftCluster;
 import org.apache.ratis.RaftTestUtil;
 import org.apache.ratis.client.RaftClient;
-import org.apache.ratis.grpc.server.GrpcService;
 import org.apache.ratis.metrics.impl.JvmMetrics;
 import org.apache.ratis.metrics.RatisMetricRegistry;
 import org.apache.ratis.protocol.RaftClientReply;
@@ -66,8 +66,8 @@ public class TestGrpcMessageMetrics extends BaseTest
 
   static void assertMessageCount(RaftServer.Division server) {
     String serverId = server.getId().toString();
-    GrpcService service = (GrpcService) RaftServerTestUtil.getServerRpc(server);
-    RatisMetricRegistry registry = service.getServerInterceptor().getMetrics().getRegistry();
+    final GrpcServicesImpl services = (GrpcServicesImpl) RaftServerTestUtil.getServerRpc(server);
+    final RatisMetricRegistry registry = services.getMessageMetrics().getRegistry();
     String counter_prefix = serverId + "_" + "ratis.grpc.RaftServerProtocolService";
     Assertions.assertTrue(
         registry.counter(counter_prefix + "_" + "requestVote" + "_OK_completed_total").getCount() > 0);

Reply via email to